jax
f15f9717 - [Pallas/TPU] Fix bug with LocalMask grid shrinking

Commit
1 year ago
[Pallas/TPU] Fix bug with LocalMask grid shrinking LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space. As a consequence, it is important that in the kernel body we use the `global_kv_index`. This is the kv_index in the "global" space without any shrinking of the iteration space. PiperOrigin-RevId: 655901432
Committer
Parents
Loading