pytorch
e2b4c63d - Enable the faster combined weight branch in MHA when query/key/value is same object with nan (#48126)

Commit
4 years ago
Enable the faster combined weight branch in MHA when query/key/value is same object with nan (#48126) Summary: Fixes https://github.com/pytorch/pytorch/issues/47979 For MHA module, it is preferred to use the combined weight branch as much as possible when query/key/value are same (in case of same values by `torch.equal` or exactly same object by `is` ops). This PR will enable the faster branch when a single object with `nan` is passed to MHA. For the background knowledge ``` import torch a = torch.tensor([float('NaN'), 1, float('NaN'), 2, 3]) print(a is a) # True print(torch.equal(a, a)) # False ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/48126 Reviewed By: gchanan Differential Revision: D25042082 Pulled By: zhangguanheng66 fbshipit-source-id: 6bb17a520e176ddbb326ddf30ee091a84fcbbf27
Author
Guanheng Zhang
Parents
Loading