xla
046f9103 - [Backport] Allow setting FlashAttention's causal mask (#6837)

Commit
1 year ago
[Backport] Allow setting FlashAttention's causal mask (#6837) Summary: This pull request channels the causal mask to our wrapper. Test Plan: PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_causal
Author
Parents
Loading