MHA optimizations (#93234)
Slight performance optimizations for regular MHA (multi-head attention), achieved by reducing the number of kernels called.
Before:
![image](https://user-images.githubusercontent.com/30204471/215349212-172c6364-9e3c-4fd1-92b6-8ddd9931613e.png)
After:
![image](https://user-images.githubusercontent.com/30204471/215349247-021dd9e6-f6ca-40a2-8de8-0805af001f69.png)
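
For readers who want to check a change of this kind on their own build, here is a minimal sketch (not taken from this PR) that counts operator calls in one `nn.MultiheadAttention` forward pass with `torch.profiler`. It assumes a CPU-only run, where operator calls are only a rough proxy for kernel launches; the tensor sizes and the helper name `count_op_calls` are hypothetical. Running it on builds before and after the change and comparing the totals would be one way to see a reduction.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def count_op_calls(module, *inputs):
    """Return (total operator calls, per-operator counts) for one forward pass.

    Hypothetical helper for illustration; not part of PyTorch or this PR.
    """
    with torch.no_grad(), profile(activities=[ProfilerActivity.CPU]) as prof:
        module(*inputs)
    # key_averages() groups recorded events by operator name.
    counts = {evt.key: evt.count for evt in prof.key_averages()}
    return sum(counts.values()), counts


# Assumed sizes, chosen only to keep the example small.
embed_dim, num_heads, batch, seq_len = 64, 4, 2, 16
mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
x = torch.randn(batch, seq_len, embed_dim)

total, per_op = count_op_calls(mha, x, x, x)
print(f"total operator calls: {total}")
# Show the ten most frequently called operators.
for name, n in sorted(per_op.items(), key=lambda kv: -kv[1])[:10]:
    print(f"{n:4d}  {name}")
```

On a GPU build, adding `ProfilerActivity.CUDA` to `activities` and moving the module and inputs to the device should record the device kernels as well, which is closer to what the description above refers to.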
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93234
Approved by: https://github.com/drisspg