flax
3f2ff40e - [pmap] Update `jax_utils.unreplicate` for performant use when `jax.config.jax_pmap_shmap_merge` is `True`.

Commit
88 days ago
[pmap] Update `jax_utils.unreplicate` for performant use when `jax.config.jax_pmap_shmap_merge` is `True`. For more details, see https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays PiperOrigin-RevId: 848531449
Author
Committer
Parents
Loading