pytorch
e0950fcc - [SDPA] Add expanded autograd testing for fused kernels and disable head_dim128 sm86 mem-efficient (#94009)

# Summary
- Adds a large parameter sweep over the various configs a user can call SDPA with, and compares the deviation of the fused kernels against the eager math fallback to test for correctness.
- sm86 + head_dim==128 is throwing an IMA (illegal memory access) in memory-efficient attention, so we add a filter to use_mem_efficient_attention(). This has since been fixed in the upstream Xformers version but will likely not make the branch cut.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94009
Approved by: https://github.com/cpuhrsch
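
The comparison the commit describes amounts to running scaled_dot_product_attention once through a fused backend and once through the eager math fallback, then checking that outputs and gradients agree within some tolerance. Below is a minimal sketch of that idea, not the test added by this PR: it assumes the PyTorch 2.0-era torch.backends.cuda.sdp_kernel context manager, and the shapes and tolerances are illustrative placeholders.

```python
# Sketch only: compare a fused SDPA backend against the eager math fallback.
# Shapes, dtype, and tolerances below are assumptions for illustration.
import torch
import torch.nn.functional as F

device, dtype = "cuda", torch.float16
batch, heads, seq_len, head_dim = 4, 8, 256, 64  # assumed config, not from the PR

def make_inputs():
    return [
        torch.randn(batch, heads, seq_len, head_dim,
                    device=device, dtype=dtype, requires_grad=True)
        for _ in range(3)
    ]

def run_sdpa(q, k, v, *, flash, mem_efficient, math):
    # sdp_kernel (PyTorch 2.0 era) restricts which SDPA backends may dispatch.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=flash, enable_mem_efficient=mem_efficient, enable_math=math
    ):
        out = F.scaled_dot_product_attention(q, k, v)
    # Each call builds a fresh graph, so autograd.grad can be taken per run.
    grads = torch.autograd.grad(out.sum(), (q, k, v))
    return out, grads

q, k, v = make_inputs()
fused_out, fused_grads = run_sdpa(q, k, v, flash=True, mem_efficient=True, math=False)
math_out, math_grads = run_sdpa(q, k, v, flash=False, mem_efficient=False, math=True)

# Tolerances here are guesses; the real test derives acceptable deviation
# from the math fallback rather than using fixed constants like these.
torch.testing.assert_close(fused_out, math_out, atol=2e-3, rtol=2e-3)
for g_fused, g_math in zip(fused_grads, math_grads):
    torch.testing.assert_close(g_fused, g_math, atol=2e-3, rtol=2e-3)
```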