xla
400bd0c9 - [Pallas] Support segment ids in flash attention (#6943)

Summary: This PR adds segment id support to the flash attention wrapper. Segment ids are a way to build an attention mask in which each token can attend only to tokens within the same segment; the resulting mask is block-diagonal. To support this, the flash attention forward pass is further split into tracing and execution parts, and the necessary shape operations are implemented to make the segment ids compatible with the kernel.

Test Plan: PJRT_DEVICE=TPU python test/test_pallas.py
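As a minimal sketch of the idea (not the PR's actual kernel code), the mask induced by segment ids can be computed by comparing query and key segment ids elementwise; `segment_mask` below is a hypothetical helper name. When tokens are grouped by segment, the mask is block-diagonal:

```python
import numpy as np

def segment_mask(q_segment_ids, kv_segment_ids):
    # mask[i, j] is True when query token i and key token j
    # belong to the same segment (i.e., attention is allowed).
    return q_segment_ids[:, None] == kv_segment_ids[None, :]

# Segments of sizes 2, 3, and 1 -> block-diagonal mask with
# blocks of 2x2, 3x3, and 1x1.
seg = np.array([0, 0, 1, 1, 1, 2])
mask = segment_mask(seg, seg)
```

Inside the kernel, positions where the mask is False would be filled with a large negative value before the softmax, so cross-segment attention weights go to zero.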