[Pallas] Support Flash Attention backward kernels (#6870)
Summary:
This change refactors custom_kernel.py to support the three new Pallas kernels involved in the Flash Attention backward pass.
The refactoring includes:
1. Adds support for static_argnums, which excludes the specified positional arguments from JAX tracing and treats them as compile-time constants.
2. Separates the JAX tracing step out so that tracing can be run on its own.
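
For context, the sketch below (not the actual custom_kernel.py code) illustrates what static_argnums does in JAX: the marked positional argument is left out of tracing and baked into the compiled function as a constant, and jax.make_jaxpr shows the tracing step performed on its own.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical example function; argument position 1 (block_size) is a
# static argument, so JAX specializes the compiled function on its value
# instead of tracing it as an array input.
@functools.partial(jax.jit, static_argnums=(1,))
def scale(x, block_size):
    return x * block_size


out = scale(jnp.ones(4), 2)
print(out)  # the static value 2 is folded into the compiled computation

# Tracing alone, without compilation or execution: make_jaxpr also
# accepts static_argnums and returns the traced jaxpr.
jaxpr = jax.make_jaxpr(lambda x: x * 2)(jnp.ones(4))
print(jaxpr)
```

Changing the value of a static argument triggers a re-trace and re-compile, which is why only arguments that vary rarely (e.g. shapes or block sizes) are usually marked static.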
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py