jax
1f30ced4 - [JAX] Enable `PjRtClient::CopyArraysForCrossHost` and `jax.device_put` to handle arrays whose shards require a mixture of host-local and cross-host transfers.

Commit
22 days ago
[JAX] Enable `PjRtClient::CopyArraysForCrossHost` and `jax.device_put` to handle arrays whose shards require a mixture of host-local and cross-host transfers. Previously, cross-host `device_put` worked only for source/destination shardings for which each shard required a cross-host transfer. This change refactors the single buffer copy logic in `PjRtArray::Copy` into a helper function so it can be called by both `PjRtArray::Copy` and `PjRtClient::CopyArraysForCrossHost` and adds a new test case to `multiprocess/array_test.py`. PiperOrigin-RevId: 857309688
Author
Parents
Loading