pytorch
80614783 - Enabling FlashAttention for SDPA when given NestedTensor (#95438)

# Summary

Previously, flash_attention was disabled for NestedTensor inputs due to an Illegal Memory Access error that occurred on the "cutlass" branch of flash-attention that had been incorporated into core. Since we switched to the main branch of flash_attention, the existing repro script no longer produces the memory error. This PR re-enables the FlashAttention path for NTs. It also unifies the nested preprocessing between the two implementations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95438
Approved by: https://github.com/mikaylagawarecki
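For context, a minimal sketch of the path this commit re-enables: calling scaled_dot_product_attention on NestedTensor inputs with only the FlashAttention backend allowed. The shapes, dtypes, and the `torch.backends.cuda.sdp_kernel` context manager reflect the PyTorch API of that era and are illustrative assumptions, not part of the commit.

```python
import torch
import torch.nn.functional as F

# Illustrative assumptions: FlashAttention requires CUDA and half precision,
# and head_dim must satisfy the kernel's constraints (64 works).
device, dtype = "cuda", torch.float16
num_heads, head_dim = 8, 64

def make_nt(seq_lens):
    # Build a NestedTensor of per-sequence (num_heads, seq_len, head_dim)
    # tensors, giving a batch of variable-length sequences.
    return torch.nested.nested_tensor(
        [torch.randn(num_heads, s, head_dim, device=device, dtype=dtype)
         for s in seq_lens]
    )

query, key, value = make_nt([10, 16]), make_nt([10, 16]), make_nt([10, 16])

# Restrict SDPA to the FlashAttention backend; with this PR, NestedTensor
# inputs dispatch to flash_attention instead of being disallowed.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(query, key, value)
```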