[PyTorch] Existing MHA: fuse the attn_mask addition (#73219)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73219
Saw a report that this elementwise add is causing overhead. IIUC this is easy to fuse?
ghstack-source-id: 152549975
Test Plan:
CI, review
Ran benchmark_transformers.par mha --batch-size 64 --max-sequence-length 128 --avg-sequence-length 256 --large --use-real-data-distribution --use-mask
and looked at the PT time number
```
before:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms, PT FLOPS: 59.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.46TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True PT Time: 1.23ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 59.57TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.75TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 58.87TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.77TFLOP/s
after:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms, PT FLOPS: 60.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.51TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 59.80TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.69TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True PT Time: 1.21ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 60.21TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.86TFLOP/s
```
Inspected a Kineto trace and confirmed that an elementwise add was fused into baddbmm.
Additional opportunity: I see a copy_ inside baddbmm that wasn't happening with the bmm path and I'm not sure why. Perhaps something went wrong with the structured kernels port by ezyang?
Reviewed By: ezyang
Differential Revision: D34160547
fbshipit-source-id: 78d406fb035e6f3bf13af2c9443a886eada35ac4
(cherry picked from commit aaffc39b24058742cb9ae42105f95b3eafe9d7f5)