Enable the faster combined weight branch in MHA when query/key/value is same object with nan (#48126)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/47979
For MHA module, it is preferred to use the combined weight branch as much as possible when query/key/value are same (in case of same values by `torch.equal` or exactly same object by `is` ops). This PR will enable the faster branch when a single object with `nan` is passed to MHA.
For the background knowledge
```
import torch
a = torch.tensor([float('NaN'), 1, float('NaN'), 2, 3])
print(a is a) # True
print(torch.equal(a, a)) # False
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48126
Reviewed By: gchanan
Differential Revision: D25042082
Pulled By: zhangguanheng66
fbshipit-source-id: 6bb17a520e176ddbb326ddf30ee091a84fcbbf27