Update assertion in MHA forward to support FP16 training (#37539)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37539
Bug fix
Test Plan:
This passed fbtranslate local integration test when I toggle fp16 to true on GPU.
Also it passed in with D21312488
Reviewed By: zhangguanheng66
Differential Revision: D21311505
fbshipit-source-id: 7ebd7375ef2c1b2ba4ac6fe7be5e7be1a490a319