jax
f2f552c1 - Allow resharding between tokens on a single device

Commit
1 year ago
Allow resharding between tokens on a single device and multiple devices. Whenever this happens we can essentially introduce an effects barrier instead of doing the normal device -> host -> device transfer. Fixes https://github.com/jax-ml/jax/issues/25671. PiperOrigin-RevId: 716309978
Author
Parents
Loading