xla
de42834d
- [Pallas] Allow setting FlashAttention's causal mask (#6792)
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
1 year ago
[Pallas] Allow setting FlashAttention's causal mask (#6792) Summary: This pull request channels the causal mask to our wrapper. Test Plan: PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_causal
References
#6792 - [Pallas] Allow setting FlashAttention's causal mask
Author
alanwaketan
Parents
8d579f91
Loading