[pmap] Remove any wrappings that have auxiliary values.
We want to avoid StoreException or StoreEmpty errors. Previous code added another transformation to reset stores. New code does away with this and also gives a chance to remove some unnecessary flatten/unflattens.
Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`.
PiperOrigin-RevId: 861864367