xla
9f2b82dc - [Pallas] Integrate FlashAttention with SPMD (#6935)

Summary:
This pull request integrates FlashAttention with SPMD. It works by creating a manual-sharding region for the kernel: all inputs are wrapped with enable_manual_sharding and all outputs with disable_manual_sharding. A new test file is added because the original test file is not SPMD-aware.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas_spmd.py
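For illustration, here is a minimal sketch of the manual-sharding pattern the summary describes, assuming the torch_xla SPMD helpers enable_manual_sharding/disable_manual_sharding and the Pallas flash_attention kernel from torch_xla.experimental.custom_kernel; the wrapper name sharded_flash_attention and the partition spec are hypothetical, and exact signatures may differ across torch_xla versions:

```python
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

def sharded_flash_attention(q, k, v, mesh, partition_spec):
    # Record the global shape so the output can be reassembled later.
    full_shape = q.shape
    # Enter the manual-sharding region: each device now operates on its
    # local shard, so the Pallas kernel sees shard-sized tensors.
    q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
    k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
    v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
    # Run the FlashAttention Pallas kernel on the local shards.
    o = flash_attention(q, k, v)
    # Leave the manual region: mark the output as sharded with the same
    # spec and restore the global shape.
    o = xs.disable_manual_sharding(o, partition_spec, full_shape, mesh=mesh).global_tensor
    return o
```

Per the summary, the change applies this wrapping around the kernel itself rather than requiring callers to do it, so SPMD users invoke the kernel on globally sharded tensors directly.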