pytorch
35c6547f - Adds 3D attn_mask support to merge_masks() for Multihead Attention fast path (#98991)

Commit
1 year ago
Adds 3D attn_mask support to merge_masks() for Multihead Attention fast path (#98991) Fixes #97409 Adds support for 3D attn_mask by always expanding attn_mask to 4D as per https://github.com/pytorch/pytorch/pull/98375#issuecomment-1499504721 Pull Request resolved: https://github.com/pytorch/pytorch/pull/98991 Approved by: https://github.com/jbschlosser
Author
Committer
Parents
Loading