Allow resharding between tokens on a single device
and multiple devices.
Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.
Fixes https://github.com/jax-ml/jax/issues/25671.
PiperOrigin-RevId: 716309978