xla
c54367c8 - [Pallas] Support Flash Attention backward kernels (#6870)

Summary:
This change refactors custom_kernel.py to support all three new Pallas kernels involved in the Flash Attention backward pass. The refactoring:
1. Adds support for static_argnums, which tells JAX tracing to ignore some positional arguments.
2. Separates the JAX tracing logic out so that tracing can be done on its own.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py
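For context, static_argnums (as in jax.jit) marks positional arguments whose values are treated as compile-time constants rather than traced values. Below is a minimal pure-Python sketch of the argument-partitioning idea only; it is hypothetical and not the actual custom_kernel.py implementation, and the names (trace_with_static_argnums, fake_trace) are invented for illustration.

```python
# Hypothetical sketch: a wrapper that splits "static" positional arguments
# from dynamic ones before handing them to a trace function, mirroring the
# static_argnums idea described in the commit message.

def trace_with_static_argnums(fn, static_argnums=()):
    static = set(static_argnums)

    def traced(*args):
        # Static args configure the kernel and are excluded from tracing;
        # dynamic args are the values the tracer should see.
        dynamic_args = [a for i, a in enumerate(args) if i not in static]
        static_args = [a for i, a in enumerate(args) if i in static]
        return fn(dynamic_args, static_args)

    return traced


def fake_trace(dynamic_args, static_args):
    # Stand-in for a real tracer: just report how the args were partitioned.
    return {"dynamic": dynamic_args, "static": static_args}


traced = trace_with_static_argnums(fake_trace, static_argnums=(1,))
result = traced("q_tensor", True, "k_tensor")  # arg 1 (True) is static
# result == {"dynamic": ["q_tensor", "k_tensor"], "static": [True]}
```

In JAX itself, marking an argument static means a change to that argument's value triggers retracing and recompilation, which is why it suits flags like a causal-mask boolean rather than tensor inputs.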