pytorch
ae819812 - [PyTorch] Handle non-vectorizable parameters for native MHA CUDA rescale kernel (#72671)

Commit
2 years ago
[PyTorch] Handle non-vectorizable parameters for native MHA CUDA rescale kernel (#72671) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72671 The existing kernel did not handle cases where D % 4 != 0 or dim_per_head % 4 != 0. Now we have a non-vectorized kernel for these cases. ghstack-source-id: 149201477 Test Plan: Updated test_nn to cover these cases. Reviewed By: zrphercule, ngimel Differential Revision: D34119371 fbshipit-source-id: 4e9b4d9b636224ef2c433593f6f236df040de782 (cherry picked from commit f5393878e4c16342ee62465bb656b18053000677)
Author
Committer
Parents
Loading