jax
de8e4e89 - [pmap] Remove the `jax_pmap_no_rank_reduction` config state.

Commit
144 days ago
[pmap] Remove the `jax_pmap_no_rank_reduction` config state. The `jax_pmap_no_rank_reduction` flag was deprecated in JAX v0.7.2 and defaulted to `True`. This change removes the flag entirely, making the no-rank-reduction behavior the only supported behavior: a `jax.pmap`ped function `f` sees inputs of the same rank as the input to `jax.pmap(f)`. For example, if `jax.pmap(f)` receives shape `(8, 128)` on 8 devices, then `f` receives shape `(1, 128)`. PiperOrigin-RevId: 856272216
Author
Parents
Loading