[Pallas] Make FlashAttention a torch.autograd.Function (#6886)
Summary:
This pull request wraps the flash attention kernel in a torch.autograd.Function so that the backward pass can be enabled on the kernel.
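
A minimal sketch of what such a wrapper typically looks like, not the PR's actual implementation: a plain softmax-attention reference stands in for the real Pallas forward and backward kernels, and all names and shapes here are illustrative.

```python
import torch

class FlashAttention(torch.autograd.Function):
    """Sketch of exposing a custom attention kernel through autograd.

    A reference softmax attention stands in for the Pallas kernels;
    in the real wrapper, forward/backward would dispatch to them.
    """

    @staticmethod
    def forward(ctx, q, k, v):
        # In the real wrapper this calls the Pallas forward kernel.
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        o = scores.softmax(dim=-1) @ v
        # Save the activations the backward pass needs.
        ctx.save_for_backward(q, k, v)
        return o

    @staticmethod
    def backward(ctx, grad_o):
        # In the real wrapper this calls the Pallas backward kernel;
        # here we recompute gradients with autograd on the reference math.
        q, k, v = ctx.saved_tensors
        with torch.enable_grad():
            q_, k_, v_ = (t.detach().requires_grad_() for t in (q, k, v))
            scores = q_ @ k_.transpose(-2, -1) / (q_.shape[-1] ** 0.5)
            o_ = scores.softmax(dim=-1) @ v_
            grads = torch.autograd.grad(o_, (q_, k_, v_), grad_o)
        return grads

# Usage: FlashAttention.apply participates in autograd like any op.
q, k, v = (torch.randn(2, 4, 128, 64, requires_grad=True) for _ in range(3))
out = FlashAttention.apply(q, k, v)
out.sum().backward()
```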
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_backward