87f40ee6 - [PyTorch] Existing MHA: fuse the attn_mask addition (#73219)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73219

Saw a report that this elementwise add is causing overhead. IIUC this is easy to fuse?

ghstack-source-id: 152549975

Test Plan: CI, review

Ran `benchmark_transformers.par mha --batch-size 64 --max-sequence-length 128 --avg-sequence-length 256 --large --use-real-data-distribution --use-mask` and looked at the PT Time number:

```
before:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms, PT FLOPS: 59.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.46TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.23ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 59.57TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.75TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 58.87TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.77TFLOP/s

after:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms, PT FLOPS: 60.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.51TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 59.80TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.69TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.21ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 60.21TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.86TFLOP/s
```

Inspected a Kineto trace and confirmed that an elementwise add was fused into baddbmm.

Additional opportunity: I see a copy_ inside baddbmm that wasn't happening with the bmm path and I'm not sure why. Perhaps something went wrong with the structured kernels port by ezyang?

Reviewed By: ezyang

Differential Revision: D34160547

fbshipit-source-id: 78d406fb035e6f3bf13af2c9443a886eada35ac4

(cherry picked from commit aaffc39b24058742cb9ae42105f95b3eafe9d7f5)
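For context, here is a minimal sketch of the kind of fusion described above (illustrative only, not the actual aten diff): the additive `attn_mask` is passed as the `input` argument of `torch.baddbmm`, so the mask addition is carried by the batched matmul call instead of running as a separate elementwise kernel. Shapes and variable names below are assumptions for the sake of the example.

```python
import math
import torch

# Illustrative shapes: (batch * num_heads, seq_len, head_dim)
B, T, E = 8, 16, 64
q = torch.randn(B, T, E)
k = torch.randn(B, T, E)
attn_mask = torch.zeros(B, T, T)  # additive mask, same shape as the score matrix

scale = 1.0 / math.sqrt(E)

# Unfused path: bmm produces the scores, then a separate elementwise add applies the mask.
scores_unfused = torch.bmm(q * scale, k.transpose(-2, -1)) + attn_mask

# Fused path: baddbmm computes beta * attn_mask + alpha * (q @ k^T) in one call,
# so the mask addition no longer needs its own elementwise op.
scores_fused = torch.baddbmm(attn_mask, q, k.transpose(-2, -1), beta=1.0, alpha=scale)

torch.testing.assert_close(scores_unfused, scores_fused)
```

Because `baddbmm` computes `beta * input + alpha * (batch1 @ batch2)`, setting `alpha` to the attention scale also folds the `1/sqrt(head_dim)` scaling into the same call.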