flax
5d908519 - [pmap] Avoid degraded performance under the new `jax.pmap`.

Commit
47 days ago
[pmap] Avoid degraded performance under the new `jax.pmap`. This change prepares for the new `jax.pmap` by implementing the recommended mechanism for accessing the first shard in a sharded array. A common pattern used with `jax.pmap` is to shard an array that is semantically replicated and grabbing the first shard is meant to "unreplicate". However, JAX does not know that a sharded array is actually replicated, so we must now explicitly grab the first shard. The change is under the `jax_pmap_shmap_merge` configuration flag. If `True`, the new `jax.pmap` implementation based on `jax.jit(jax.shard_map)` is used and requires the new explicit shard access. If `False`, the old `jax.pmap` implementation is used and there is a special case in how `x[0]` works. Please see details here: https://docs.jax.dev/en/latest/migrate_pmap.html#int-array-indexing-into-sharded-arrays PiperOrigin-RevId: 852526292
Author
Committer
Parents
Loading