jax
cf763131 - [pmap] Speed up the new `jax.pmap` implementation.

Commit
27 days ago
[pmap] Speed up the new `jax.pmap` implementation. This change: - Passes `jax/tests:pmap_test` and `jax/tests/multiprocess:pmap_test`. - Caches the pmap'd function on the same set of inputs as the previous version of the new `jax.pmap`. Optimizations and clean-ups: - Build `core.DebugInfo` and `linear_util.WrappedFun` once when the `jax.pmap` transformation is invoked, not each time when the pmap'd function is called. - Call impl-style functions for host<>global array conversions directly when `core.trace_state_clean()` (i.e., `jax.pmap` is the top-level transformation) to avoid binding overhead. - Cache `avals`, `sharding`, and other per-array calls when converting between host-local arrays and global arrays. - Expose a no data move/copy sharding metadata rewrap. - Remove all `linear_util` wrapping that use `linear_util.Store` to avoid `StoreException`s. What's missing - The `.lower()` path will not behave correctly when further calling `.compile()(*inputs)`. However, the previous version of `jit(shard_map)` `jax.pmap` had the same issue. Updated documentation to discuss this. Differences w.r.t. multihost_util's host-local <> global conversion functions: - Process the entire array of args in one function to avoid O(N) python function calls (maybe small cost). - Use cached in_local/global_shardings (can cache in the mhu versions, but then we have cache look-up costs). - Using `arr._rewrap_with_aval_and_sharding` (this one can be a drop-in for a fast path). - Cached `_local_device_indices` values, `_aval`, and `_is_sharding_equivalent` (these can be ported over). - Minor optimizations around accessing python attributes (these can be ported over). - Implements handling for donated args in both fast and slow paths. This might require some API changes to mhu. PiperOrigin-RevId: 862035366
Author
Parents
Loading