pytorch
4c78c7c8 - Enable `src_mask` in fast path of `TransformerEncoderLayer ` (#87377)

Commit
2 years ago
Enable `src_mask` in fast path of `TransformerEncoderLayer ` (#87377) ## Issues Fixes https://github.com/pytorch/pytorch/issues/81129#issuecomment-1179435674 ## Description Passing a 2D attention mask `src_mask` into the fast path of `TransformerEncoderLayer` in CPU was causing an error and so was disabled in https://github.com/pytorch/pytorch/pull/81277. This PR unrolls this fix, enabling `src_mask` on the fast path: - Either attention mask `src_mask` of shape `(L, L)` or padding mask `src_key_padding_mask` of shape `(B, L)` are now allowed on the CPU fast path. If softmax is applied along the last dimension (as in multi-head attention), these masks are processed without expanding them to 4D. Instead, when iterating through the input, `Softmax.cpp::host_softmax` converts the index to match the mask dimensions, depending on the type. - If softmax is applied along the dimension other than the last, `Softmax.cpp::masked_softmax_cpu` expands masks to 4D, converting them to `mask_type=2`. Theoretically one could also add special optimized cases for `dim=0, 1, 2` and process them without mask expansion, but I don't know how often is that used ## Tests: - `test_transformerencoderlayer_fast_path` is extended to cover both attention mask and padding mask - `test_masked_softmax_mask_types_0_1` is added to ensure results from CPU softmax with attention and padding masks match the explicit slow calculation - `test_masked_softmax_devices_parity` is added to ensure results from masked softmax on CPU and CUDA match ## Note I had to replace `float` with `torch.get_default_dtype()` in a couple of tests for the following reason: - `test_nn.py` [sets the default type to `torch.double`](https://github.com/pytorch/pytorch/blob/master/test/test_nn.py#L24-L26) - If I execute `test_nn.py` and `test_transformers.py` in one `pytest` run, this default still holds for transformer tests - Some tests in `test_transformers.py` which were previously following the slow path now switched to fast path, and hard-coded `float` started clashing with default `double` Let me know if there is a better way around it - or maybe I'm not supposed to run tests with `pytest` like this Pull Request resolved: https://github.com/pytorch/pytorch/pull/87377 Approved by: https://github.com/mikekgfb, https://github.com/weiwangmeta, https://github.com/malfet
Author
Committer
Parents
Loading