flax
3f2ff40e
- [pmap] Update `jax_utils.unreplicate` for performant use when `jax.config.jax_pmap_shmap_merge` is `True`.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
88 days ago
[pmap] Update `jax_utils.unreplicate` for performant use when `jax.config.jax_pmap_shmap_merge` is `True`. For more details, see https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays PiperOrigin-RevId: 848531449
Author
danielsuo
Committer
a-googler
Parents
d1791844
Loading