xla
de42834d - [Pallas] Allow setting FlashAttention's causal mask (#6792)

Commit
1 year ago
Summary:
This pull request channels the causal mask to our wrapper.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_causal
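
For readers unfamiliar with the option, the snippet below is a minimal sketch of what a causal mask means for attention. It is not the Pallas kernel or the wrapper touched by this commit; `reference_attention` and its `causal` parameter are hypothetical names used only for illustration, and the function assumes a single head with equal query and key lengths.

```python
import math


def reference_attention(q, k, v, causal=False):
    """Naive single-head attention over lists of row vectors.

    Hypothetical reference code for illustration only; assumes
    len(q) == len(k) == len(v).
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # With causal=True, query i may only attend to keys 0..i.
        visible = range(i + 1) if causal else range(len(k))
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in visible]
        # Numerically stable softmax over the visible keys.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][d] for w, j in zip(weights, visible)) / total
            for d in range(len(v[0]))
        ])
    return out
```

With the mask enabled, the first query position can only see the first key, so its output is simply the first value row. The last query position sees every key, so its output is the same with or without the mask.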