jax
2c85ca6f - If callback returns a fully replicated global array, return it as is.

Commit
1 year ago
If callback returns a fully replicated global array, return it as is. Also take the batched_device_put fast path for non-jax.Array's since slicing can return arrays on multiple devices which batched_device_put doesn't support. PiperOrigin-RevId: 624763603
Author
Committer
Parents
Loading