[PyTorch] Handle non-vectorizable parameters for native MHA CUDA rescale kernel (#72671)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72671
The existing kernel did not handle cases where D % 4 != 0 or dim_per_head % 4 != 0. Now we have a non-vectorized kernel for these cases.
ghstack-source-id: 149201477
Test Plan: Updated test_nn to cover these cases.
Reviewed By: zrphercule, ngimel
Differential Revision: D34119371
fbshipit-source-id: 4e9b4d9b636224ef2c433593f6f236df040de782
(cherry picked from commit f5393878e4c16342ee62465bb656b18053000677)