Disable nn.MHA fastpath for floating point masks (#107641)
Fixes https://github.com/pytorch/pytorch/issues/107084 by disabling the fast path when floating point masks (which should be additive) are passed
- [We claim in our docs for MHA that float masks will be added to the attention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (be it `key_padding_mask` or `attn_mask`)
- We always canonicalize any mask at the start of MHA in python by converting it to float
- my understanding from Driss is that SDPA properly supports additive masking (but there are many special cases for mask shape for MHA that don't work properly currently (BxT, TxT) so [we're turning this off for now](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L531-L532)
- More broadly, the problem isn't with the SDPA path, but that things are broken for the path it falls back to
- Right now mha "fast path" code with non-None masks is always going through [this path ](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L554-L640) that has a call to `masked_softmax` that [converts the masks back to bool](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L154-L156)
- the implication here is that **additive floating point attn_mask and additive key_padding_mask to nn.MHA fastpath are broken**
- This wasn't broken for the user in [https://github.com/pytorch/pytorch/issues/107084](https://l.workplace.com/l.php?u=https%3A%2F%2Fgithub.com%2Fpytorch%2Fpytorch%2Fissues%2F107084&h=AT35qHIQavtxKtriTkrkPsWRB3eSRh4qH5PQUyiTzrPTshoztPL0593AmKCmSdEQ5O-5wib0Fd4mwztVu4YbMWb2ghZnZw1pvpJb9-FYWjDsPQ6_oHRVPzFfj8xYXC1TaFnJCkMYjrGXkIfzzxZvmcQYNnIPgsJSiWgjIw) in 1.13.1 because of [this check which bypassed the fast path if attn_mask was defined](https://github.com/pytorch/pytorch/blob/v1.13.1/torch/nn/modules/activation.py#L1096-L1097) (as Driss pointed out though additive key_padding_mask with the fast path were probably broken in 1.13.1)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107641
Approved by: https://github.com/drisspg, https://github.com/jbschlosser