[pmap] Speed up the new `jax.pmap` implementation.
This change:
- Passes `jax/tests:pmap_test` and `jax/tests/multiprocess:pmap_test`.
- Caches the pmap'd function on the same set of inputs as the previous version of the new `jax.pmap`.
Optimizations and clean-ups:
- Build `core.DebugInfo` and `linear_util.WrappedFun` once, when the `jax.pmap` transformation is invoked, rather than on every call to the pmap'd function.
- Call impl-style functions directly for host-local <> global array conversions when `core.trace_state_clean()` is true (i.e., `jax.pmap` is the top-level transformation), to avoid binding overhead.
- Cache `avals`, `sharding`, and other per-array lookups when converting between host-local arrays and global arrays.
- Expose a way to rewrap an array with new sharding metadata without moving or copying data.
- Remove all `linear_util` wrapping that uses `linear_util.Store`, to avoid `StoreException`s.
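The first item above can be illustrated with a toy sketch. This is not the JAX implementation: `toy_pmap`, `ToyDebugInfo`, `build_debug_info`, and `BUILD_COUNT` are hypothetical names used only to show per-function metadata being built when the transformation is applied instead of inside the returned callable.

```python
import inspect
from dataclasses import dataclass


# Hypothetical stand-in for per-function metadata such as core.DebugInfo.
@dataclass(frozen=True)
class ToyDebugInfo:
    func_name: str
    arg_names: tuple


BUILD_COUNT = 0  # counts how often the metadata is constructed


def build_debug_info(fun):
    global BUILD_COUNT
    BUILD_COUNT += 1
    return ToyDebugInfo(fun.__name__, tuple(inspect.signature(fun).parameters))


def toy_pmap(fun):
    # Built once, when the transformation is applied...
    dbg = build_debug_info(fun)

    def wrapped(*args):
        # ...and only read here, on every call.
        assert len(args) == len(dbg.arg_names)
        # Map the function over the leading axis of each argument.
        return [fun(*xs) for xs in zip(*args)]

    return wrapped


def add(x, y):
    return x + y


mapped = toy_pmap(add)
results = [mapped([1, 2], [10, 20]) for _ in range(3)]
```

After three calls to `mapped`, `BUILD_COUNT` is still 1 in this sketch, because the metadata is constructed in `toy_pmap` and not in `wrapped`.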
What's missing:
- The `.lower()` path does not behave correctly when the result is subsequently called via `.compile()(*inputs)`. The previous `jit(shard_map)`-based version of `jax.pmap` had the same issue. The documentation is updated to discuss this.
Differences w.r.t. the host-local <> global conversion functions in `multihost_utils` (mhu):
- Process the entire list of args in one function to avoid O(N) Python function calls (the saving may be small).
- Use cached in_local/global_shardings (the mhu versions could cache these too, but would then pay cache look-up costs).
- Use `arr._rewrap_with_aval_and_sharding` (this one can be a drop-in for a fast path).
- Cache `_local_device_indices` values, `_aval`, and `_is_sharding_equivalent` (these can be ported over).
- Apply minor optimizations around accessing Python attributes (these can be ported over).
- Implement handling for donated args in both fast and slow paths. This might require some API changes to mhu.
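The batch conversion and metadata-only rewrap described above can be sketched in plain Python. This is a toy model, not the code in this change: `ToyArray`, `rewrap`, and `convert_all` are hypothetical names, and the sketch assumes the underlying data is already laid out compatibly with the target sharding, so that only the metadata has to change.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyArray:
    buffer: list   # stands in for on-device data
    sharding: str  # stands in for sharding metadata

    def rewrap(self, sharding):
        # New wrapper around the same buffer object: nothing is moved or copied.
        return ToyArray(self.buffer, sharding)


def convert_all(args, target_sharding):
    # One call handles every argument, instead of one Python call per argument.
    out = []
    for a in args:
        if a.sharding == target_sharding:
            out.append(a)  # already carries the target metadata
        else:
            out.append(a.rewrap(target_sharding))  # metadata-only fast path
    return out


local_arrays = [ToyArray([1, 2], "local"), ToyArray([3, 4], "local")]
global_arrays = convert_all(local_arrays, "global")
```

In this sketch each element of `global_arrays` shares its `buffer` object with the corresponding input, which is the property a no-copy rewrap is meant to preserve.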
PiperOrigin-RevId: 862035366