DeepSpeed
0e0748c5 - adds triton flash attention2 kernel (#4337)

adds triton flash attention2 kernel (#4337)

* initial commit
* temp commit: needs debugging
* packed flash attn with mask works
* clean-up
* add bert/roberta tests to test_inference
* is_triton_supported added to Accelerator class; clean-up and formatting
* triton supports the flash attention when compute cap > 8.0
* formatting
* fix comments
* cleanup
* cleanup flash kernel
* fix according to the PR comment

---------

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
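The commit notes that the Triton flash attention path is gated on GPU compute capability (the message says "compute cap > 8.0", i.e. Ampere-class hardware or newer). A minimal sketch of such a capability check is below; the function name and threshold constant here are illustrative, not the actual DeepSpeed `Accelerator.is_triton_supported` implementation.

```python
# Hypothetical sketch of a compute-capability gate for the Triton
# flash-attention kernel. Assumes the check is "capability >= 8.0"
# (Ampere or newer), per the commit message.

MIN_TRITON_FA_CAPABILITY = (8, 0)  # illustrative constant, not from DeepSpeed


def is_triton_flash_attn_supported(compute_capability):
    """Return True if a GPU with the given (major, minor) compute
    capability can run the Triton flash-attention kernel.

    In practice the capability tuple would come from
    torch.cuda.get_device_capability(); it is passed in here so the
    sketch stays self-contained.
    """
    return tuple(compute_capability) >= MIN_TRITON_FA_CAPABILITY
```

For example, an A100 (capability 8.0) or H100 (9.0) would pass this check, while a T4 (7.5) would fall back to the non-Triton path.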