fdeee436 - Disable SDPA FlashAttention backward and mem eff attention on sm86+ for head_dim above 64 (#99105)

Expand the check in sdpa_utils.h so that FlashAttention's backward pass and memory-efficient attention are disabled when both of the following hold:

- head_dim > 64
- sm86 or newer

Previously these kernels were only disabled on sm86, and only for head_dim equal to 128.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99105
Approved by: https://github.com/malfet
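For context, the dispatch condition described above can be sketched in Python. This is a minimal illustration of the constraint, not the actual sdp_utils.h C++ check; the helper name and structure are assumptions for this example:

```python
import torch

def fused_sdpa_backward_supported(head_dim: int, device: int = 0) -> bool:
    """Sketch of the constraint described in this commit: on sm86 or
    newer GPUs, FlashAttention backward and memory-efficient attention
    are rejected for head_dim > 64 (previously only sm86 with
    head_dim == 128 was excluded).

    Hypothetical helper, not a PyTorch API.
    """
    major, minor = torch.cuda.get_device_capability(device)
    sm = 10 * major + minor
    return not (sm >= 86 and head_dim > 64)

# Example: on an sm86 card (e.g. an RTX 3090, compute capability 8.6)
# this returns False for head_dim=128, matching the old behavior,
# and now also False for any head_dim above 64.
if torch.cuda.is_available():
    print(fused_sdpa_backward_supported(head_dim=128))
```

When the fused kernels are rejected by checks like this one, torch.nn.functional.scaled_dot_product_attention falls back to another eligible backend, such as the math implementation, rather than failing outright.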