xla
0c704cf8 - [Pallas] Make FlashAttention as torch.autograd.Function (#6886)

Summary: This pull request wraps the flash attention kernel in a torch.autograd.Function so that backward can be enabled on the kernel.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_backward
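A minimal sketch of the torch.autograd.Function pattern the commit describes. The kernel entry point here (`_flash_attention_fwd`) is a hypothetical stand-in using plain attention, not torch_xla's actual Pallas kernel, and the backward pass recomputes gradients via autograd instead of calling a dedicated Pallas backward kernel:

```python
import torch


def _flash_attention_fwd(q, k, v):
    # Hypothetical stand-in for the Pallas forward kernel:
    # plain scaled dot-product attention, for illustration only.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    probs = torch.softmax(scores, dim=-1)
    return probs @ v


class FlashAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        out = _flash_attention_fwd(q, k, v)
        # Save the inputs the backward pass needs.
        ctx.save_for_backward(q, k, v)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        # The real implementation would invoke a Pallas backward kernel;
        # here we recompute the forward under autograd as a placeholder.
        with torch.enable_grad():
            q_, k_, v_ = (t.detach().requires_grad_() for t in (q, k, v))
            out = _flash_attention_fwd(q_, k_, v_)
            grads = torch.autograd.grad(out, (q_, k_, v_), grad_out)
        return grads


# Usage: gradients now flow through the custom Function.
q, k, v = (torch.randn(2, 4, 8, requires_grad=True) for _ in range(3))
FlashAttention.apply(q, k, v).sum().backward()
```

Because torch.autograd only sees opaque custom ops, wrapping the kernel in a Function with an explicit backward is what makes `loss.backward()` work through the fused kernel.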