jax
ed952c8e - [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.

Commit
1 year ago
[Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector. PiperOrigin-RevId: 722015152
Author
Parents
Loading