jax
a2912b07 - [pmap] Optimize host-local <> global conversions.

Commit
11 days ago
[pmap] Optimize host-local <> global conversions. Eventually, these functions could be put in mhu, but there are a few differences: - Process the entire array of args in one function to avoid O(N) python function calls (maybe small cost). - Use cached in_local/global_shardings (can cache in the mhu versions, but then we have cache look-up costs). - Using arr._rewrap_with_aval_and_sharding (this one can be a drop-in for a fast path). - Cached _local_device_indices values, _aval, and _is_sharding_equivalent (these can be ported over). - Minor optimizations around accessing python attributes (these can be ported over). Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 861971854
Author
Parents
Loading