[PyTorch] Don't enter MHA fast path when bias & query dtypes don't match
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76879
The fast path does not support this: transform_bias_rescale_qkv will try to grab bias.data_ptr() assuming the dtypes are the same. (Also, I have no idea how this happens.)
Differential Revision: [D36156872](https://our.internmc.facebook.com/intern/diff/D36156872/)
Approved by: https://github.com/cpuhrsch