[Pallas] Integrate FlashAttention with SPMD (#6935)
Summary:
This pull request integrates FlashAttention with SPMD. It works by creating a manual sharding region for the kernel: all inputs are wrapped with enable_manual_sharding and all outputs with disable_manual_sharding.
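A minimal sketch of that wrapping pattern, assuming the torch_xla SPMD helpers xs.enable_manual_sharding / xs.disable_manual_sharding behave as described here; _flash_attention_impl is a hypothetical stand-in for the actual Pallas kernel call:

```python
import torch_xla.distributed.spmd as xs

def flash_attention_spmd(q, k, v, partition_spec, mesh, _flash_attention_impl):
    # Remember the global shape so the output can be reassembled later.
    full_shape = q.shape
    # Enter the manual sharding region: each input is replaced by its
    # per-device shard so the Pallas kernel only sees local data.
    q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
    k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
    v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
    # Run the (hypothetical) single-device Pallas flash-attention kernel
    # on the local shard.
    o = _flash_attention_impl(q, k, v)
    # Leave the manual sharding region: stitch the per-device outputs back
    # into a tensor with the original global shape.
    o = xs.disable_manual_sharding(o, partition_spec, full_shape, mesh=mesh).global_tensor
    return o
```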
Added a new test file because the original test file is not SPMD-aware.
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas_spmd.py
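For reference, a hedged sketch of the kind of invocation the new test exercises, assuming flash_attention gains partition_spec and mesh keyword arguments in this change; shapes and the mesh layout are illustrative:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
# Shard the batch dimension across all devices; other dims stay replicated.
mesh = xs.Mesh(np.arange(n_devices), (n_devices, 1, 1, 1))

q = torch.randn(4, 2, 128, 128).to(xm.xla_device())
k = torch.randn(4, 2, 128, 128).to(xm.xla_device())
v = torch.randn(4, 2, 128, 128).to(xm.xla_device())
o = flash_attention(q, k, v, partition_spec=(0, None, None, None), mesh=mesh)
```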