fcf24b6c - [Pallas] Make a FlashAttention Wrapper (#6785)

Summary:
This pull request introduces a FlashAttention wrapper that aims to:
1. Override some default settings to get the best performance out of the box.
2. Ease the UX so that users don't have to do all the custom_kernel paperwork.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper
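
A minimal usage sketch of what the wrapper enables. It assumes the wrapper is exposed as `flash_attention` in `torch_xla.experimental.custom_kernel` (the module the commit message alludes to); the exact import path, signature, and tensor shapes are assumptions, not confirmed by this commit page.

```python
# Sketch only: import path and signature are assumptions based on the
# commit message, not verified against this commit's diff.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()  # expects PJRT_DEVICE=TPU, as in the test plan

# Query/key/value tensors in a (batch, heads, seq_len, head_dim) layout.
q = torch.randn(4, 2, 128, 64, device=device)
k = torch.randn(4, 2, 128, 64, device=device)
v = torch.randn(4, 2, 128, 64, device=device)

# One call replaces the manual custom_kernel plumbing; per the summary,
# the wrapper picks sensible defaults for the underlying Pallas kernel.
o = flash_attention(q, k, v)
```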