[PyTorch] MHA: add debug shape checks (#72457)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72457
No cost to adding this basic self-check on our operations.
ghstack-source-id: 149067335
Test Plan: CI
Reviewed By: zrphercule
Differential Revision: D33954672
fbshipit-source-id: f57b3c2463db403431f884db56063cec2ca93ef2
(cherry picked from commit 3fb79f3cf8e2de8cfc4d9efc1cec0f9ed7a7ecd3)