[pmap] Avoid degraded performance under the new `jax.pmap`.
This change prepares for the new `jax.pmap` by implementing the recommended mechanism for accessing the first shard in a sharded array. A common pattern used with `jax.pmap` is to shard an array that is semantically replicated and grabbing the first shard is meant to "unreplicate". However, JAX does not know that a sharded array is actually replicated, so we must now explicitly grab the first shard.
The change is under the `jax_pmap_shmap_merge` configuration flag. If `True`, the new `jax.pmap` implementation based on `jax.jit(jax.shard_map)` is used and requires the new explicit shard access. If `False`, the old `jax.pmap` implementation is used and there is a special case in how `x[0]` works.
Please see details here: https://docs.jax.dev/en/latest/migrate_pmap.html#int-array-indexing-into-sharded-arrays
PiperOrigin-RevId: 852526292