enable head_dim > 64 for flash attention on sm90 (#99776)
Follow-up to #99105, which disabled FlashAttention (when used with autograd) and memory-efficient attention for the following cases:
- head_dim > 64
- sm86 or newer

We have tested enabling FlashAttention on sm90 and it works, so this PR re-enables it for sm90 and adds a test.
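For reference, a minimal sketch of the configuration this PR re-enables: scaled_dot_product_attention with head_dim > 64, gradients required, restricted to the flash backend. The shapes and the use of `torch.backends.cuda.sdp_kernel` (the backend-selection context manager available around the time of this PR) are illustrative assumptions, and the snippet assumes an sm90 GPU such as H100.

```python
# Illustrative sketch (not part of this PR): exercising FlashAttention with
# head_dim > 64 and autograd, the path re-enabled on sm90 by this change.
import torch
import torch.nn.functional as F

device = "cuda"  # assumes an sm90 device (e.g. H100)
dtype = torch.float16

# head_dim = 128 (> 64); previously this fell back to other backends on sm86+
# when gradients were required.
batch, num_heads, seq_len, head_dim = 2, 8, 1024, 128
q = torch.randn(batch, num_heads, seq_len, head_dim,
                device=device, dtype=dtype, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Restrict SDPA to the flash backend so the kernel selection is explicit.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                     enable_math=False,
                                     enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

# Backward pass runs FlashAttention's autograd path.
out.sum().backward()
```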
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99776
Approved by: https://github.com/malfet, https://github.com/drisspg