pytorch
81bbee7d - [SDPA] Adds basic correctness checks (#94274)

Commit
1 year ago
[SDPA] Adds basic correctness checks (#94274) # Summary Add more checks around shape constraints as well as update the sdp_utils to properly catch different head_dims between qk and v for flash_attention which is not supported. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94274 Approved by: https://github.com/cpuhrsch
Author
Committer
Parents
Loading