Add support for bool/byte `attn_mask` tensor in MultiheadAttention/Transformer modules (#33763)
Summary:
Add the support to accept both float, byte, and bool tensors for `attn_mask`. No breakage is expected.
- If a bool tensor is provided, positions with `True` are not allowed to attend while `False` values will be unchanged.
- if a byte tensor is provided, it will be converted to bool tensor. Positions with non-zero are not allowed to attend while zero values will be unchanged.
- If a float tensor is provided, it will be added to the attention weight.
Note: the behavior of the float mask tensor is slightly different from the first two options because it is added to the attention weight, rather than calling `masked_fill_` function. Also, converting a byte tensor to bool tensor within `multi_head_attention_forward` causes extra overhead. Therefore, a bool mask is recommended here.
For `key_padding_mask`:
- if a bool tensor is provided, it will be converted to bool tensor. The positions with the value of `True` will be ignored while the position with the value of `False` will be unchanged.
- If a byte tensor is provided, the positions with the value of non-zero will be ignored while the position with the value of zero will be unchanged.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33763
Differential Revision: D20925358
Pulled By: zhangguanheng66
fbshipit-source-id: de174056be183cdad0f3de8024ee0a3c5eb364c9