pytorch
3f25fc3b - Adds 3D attn_mask support to merge_masks() for Multihead Attention fast path (#98991) (#99092)

Commit
2 years ago
Fixes #97409

Adds support for 3D attn_mask by always expanding attn_mask to 4D, as per https://github.com/pytorch/pytorch/pull/98375#issuecomment-1499504721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98991
Approved by: https://github.com/jbschlosser
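For readers unfamiliar with the mask shapes involved, here is a minimal sketch of what "always expanding attn_mask to 4D" can look like. It is not the code from this commit: `expand_attn_mask_to_4d` is a hypothetical standalone helper, and it assumes the usual `nn.MultiheadAttention` convention that a 2D mask has shape `(L, S)` and a 3D mask has shape `(batch_size * num_heads, L, S)` with the batch index varying slowest.

```python
import torch


def expand_attn_mask_to_4d(attn_mask, batch_size, num_heads):
    """Hypothetical helper: return attn_mask as (batch_size, num_heads, L, S)."""
    if attn_mask.dim() == 3:
        # Assumed layout: (batch_size * num_heads, L, S), batch-major.
        tgt_len, src_len = attn_mask.shape[-2:]
        return attn_mask.reshape(batch_size, num_heads, tgt_len, src_len)
    if attn_mask.dim() == 2:
        # (L, S): one mask shared by every batch element and every head.
        tgt_len, src_len = attn_mask.shape
        return attn_mask.reshape(1, 1, tgt_len, src_len).expand(
            batch_size, num_heads, -1, -1
        )
    raise ValueError(f"expected a 2D or 3D attn_mask, got {attn_mask.dim()}D")


batch_size, num_heads, seq_len = 2, 4, 5
mask_2d = torch.zeros(seq_len, seq_len)
mask_3d = torch.arange(batch_size * num_heads * seq_len * seq_len, dtype=torch.float32)
mask_3d = mask_3d.reshape(batch_size * num_heads, seq_len, seq_len)

out_2d = expand_attn_mask_to_4d(mask_2d, batch_size, num_heads)
out_3d = expand_attn_mask_to_4d(mask_3d, batch_size, num_heads)
```

Once both cases share the 4D shape, a key padding mask of shape `(batch_size, S)` can be reshaped to `(batch_size, 1, 1, S)` and combined with the result by broadcasting, so one code path covers 2D and 3D inputs.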