[pmap] Optimize host-local <> global conversions.
Eventually, these functions could be put in mhu, but there are a few differences:
- Process the entire array of args in one function to avoid O(N) python function calls (maybe small cost).
- Use cached in_local/global_shardings (can cache in the mhu versions, but then we have cache look-up costs).
- Using arr._rewrap_with_aval_and_sharding (this one can be a drop-in for a fast path).
- Cached _local_device_indices values, _aval, and _is_sharding_equivalent (these can be ported over).
- Minor optimizations around accessing python attributes (these can be ported over).
Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`.
PiperOrigin-RevId: 861971854