Disable SDPA FlashAttention backward and mem eff attention on sm86+ for head_dim above 64 (#99105)
Expands the check in sdpa_utils.h to disable FlashAttention's backward pass (when autograd is required) and memory-efficient attention in the following cases:
- head_dim > 64
- sm86 or newer
Previously we only disabled these kernels on sm86 and only for head_dim equal to 128. A rough sketch of the expanded guard is shown below.
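The following is a minimal sketch, not the actual sdpa_utils.h code; the struct and function names (sdp_params_sketch, can_use_fused_attention_sketch) are assumptions for illustration of the condition this PR expands.

```cpp
#include <cstdint>

// Hypothetical stand-in for the dispatch parameters (names are illustrative).
struct sdp_params_sketch {
  int64_t head_dim;    // size of the last dimension of q/k/v
  bool requires_grad;  // whether a backward pass (autograd) is needed
};

// Returns true if the fused FlashAttention backward / mem-efficient
// attention kernels may be used for this input on the given GPU.
inline bool can_use_fused_attention_sketch(
    const sdp_params_sketch& params,
    int cc_major,
    int cc_minor) {
  const bool is_sm86_or_newer =
      (cc_major > 8) || (cc_major == 8 && cc_minor >= 6);
  // Old behavior: reject only head_dim == 128 on sm86.
  // Expanded behavior: reject any head_dim > 64 on sm86 or newer
  // when a backward pass is required.
  if (is_sm86_or_newer && params.requires_grad && params.head_dim > 64) {
    return false;
  }
  return true;
}
```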
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99105
Approved by: https://github.com/malfet