[nn] mha : no-batch-dim support (python) (#67176)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585
* [x] Update docs
* [x] Tests for shape checking
Tests take roughly 20s on system that I use. Below is the timings for slowest 20 tests.
```
pytest test/test_modules.py -k _multih --durations=20
============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.0, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /home/kshiteej/Pytorch/pytorch_no_batch_mha, configfile: pytest.ini
plugins: hypothesis-6.23.2, repeat-0.9.1
collected 372 items / 336 deselected / 36 selected
test/test_modules.py ..............ssssssss.............. [100%]
================================================================================================ warnings summary ================================================================================================
../../.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:73
test/test_modules.py::TestModuleCUDA::test_factory_kwargs_nn_MultiheadAttention_cuda_float32
/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:73: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.
warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/warnings.html
============================================================================================== slowest 20 durations ==============================================================================================
8.66s call test/test_modules.py::TestModuleCUDA::test_gradgrad_nn_MultiheadAttention_cuda_float64
2.02s call test/test_modules.py::TestModuleCPU::test_gradgrad_nn_MultiheadAttention_cpu_float64
1.89s call test/test_modules.py::TestModuleCUDA::test_grad_nn_MultiheadAttention_cuda_float64
1.01s call test/test_modules.py::TestModuleCUDA::test_factory_kwargs_nn_MultiheadAttention_cuda_float32
0.51s call test/test_modules.py::TestModuleCPU::test_grad_nn_MultiheadAttention_cpu_float64
0.46s call test/test_modules.py::TestModuleCUDA::test_forward_nn_MultiheadAttention_cuda_float32
0.45s call test/test_modules.py::TestModuleCUDA::test_non_contiguous_tensors_nn_MultiheadAttention_cuda_float64
0.44s call test/test_modules.py::TestModuleCUDA::test_non_contiguous_tensors_nn_MultiheadAttention_cuda_float32
0.21s call test/test_modules.py::TestModuleCUDA::test_pickle_nn_MultiheadAttention_cuda_float64
0.21s call test/test_modules.py::TestModuleCUDA::test_pickle_nn_MultiheadAttention_cuda_float32
0.18s call test/test_modules.py::TestModuleCUDA::test_forward_nn_MultiheadAttention_cuda_float64
0.17s call test/test_modules.py::TestModuleCPU::test_non_contiguous_tensors_nn_MultiheadAttention_cpu_float32
0.16s call test/test_modules.py::TestModuleCPU::test_non_contiguous_tensors_nn_MultiheadAttention_cpu_float64
0.11s call test/test_modules.py::TestModuleCUDA::test_factory_kwargs_nn_MultiheadAttention_cuda_float64
0.08s call test/test_modules.py::TestModuleCPU::test_pickle_nn_MultiheadAttention_cpu_float32
0.08s call test/test_modules.py::TestModuleCPU::test_pickle_nn_MultiheadAttention_cpu_float64
0.06s call test/test_modules.py::TestModuleCUDA::test_repr_nn_MultiheadAttention_cuda_float64
0.06s call test/test_modules.py::TestModuleCUDA::test_repr_nn_MultiheadAttention_cuda_float32
0.06s call test/test_modules.py::TestModuleCPU::test_forward_nn_MultiheadAttention_cpu_float32
0.06s call test/test_modules.py::TestModuleCPU::test_forward_nn_MultiheadAttention_cpu_float64
============================================================================================ short test summary info =============================================================================================
=========================================================================== 28 passed, 8 skipped, 336 deselected, 2 warnings in 19.71s ===========================================================================
```
cc albanD mruberry jbschlosser walterddr
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67176
Reviewed By: dagitses
Differential Revision: D33094285
Pulled By: jbschlosser
fbshipit-source-id: 0dd08261b8a457bf8bad5c7f3f6ded14b0beaf0d