Enable `src_mask` in fast path of `TransformerEncoderLayer ` (#87377)
## Issues
Fixes https://github.com/pytorch/pytorch/issues/81129#issuecomment-1179435674
## Description
Passing a 2D attention mask `src_mask` into the fast path of `TransformerEncoderLayer` in CPU was causing an error and so was disabled in https://github.com/pytorch/pytorch/pull/81277. This PR unrolls this fix, enabling `src_mask` on the fast path:
- Either attention mask `src_mask` of shape `(L, L)` or padding mask `src_key_padding_mask` of shape `(B, L)` are now allowed on the CPU fast path. If softmax is applied along the last dimension (as in multi-head attention), these masks are processed without expanding them to 4D. Instead, when iterating through the input, `Softmax.cpp::host_softmax` converts the index to match the mask dimensions, depending on the type.
- If softmax is applied along the dimension other than the last, `Softmax.cpp::masked_softmax_cpu` expands masks to 4D, converting them to `mask_type=2`. Theoretically one could also add special optimized cases for `dim=0, 1, 2` and process them without mask expansion, but I don't know how often is that used
## Tests:
- `test_transformerencoderlayer_fast_path` is extended to cover both attention mask and padding mask
- `test_masked_softmax_mask_types_0_1` is added to ensure results from CPU softmax with attention and padding masks match the explicit slow calculation
- `test_masked_softmax_devices_parity` is added to ensure results from masked softmax on CPU and CUDA match
## Note
I had to replace `float` with `torch.get_default_dtype()` in a couple of tests for the following reason:
- `test_nn.py` [sets the default type to `torch.double`](https://github.com/pytorch/pytorch/blob/master/test/test_nn.py#L24-L26)
- If I execute `test_nn.py` and `test_transformers.py` in one `pytest` run, this default still holds for transformer tests
- Some tests in `test_transformers.py` which were previously following the slow path now switched to fast path, and hard-coded `float` started clashing with default `double`
Let me know if there is a better way around it - or maybe I'm not supposed to run tests with `pytest` like this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87377
Approved by: https://github.com/mikekgfb, https://github.com/weiwangmeta, https://github.com/malfet