jax
b9cf0af5 - Implemented cross-host memory transfer on GPU.

Commit
223 days ago
Implemented cross-host memory transfer on GPU.

# Background

Emily is currently working on extending `jax.device_put` to allow for cross-host memory transfers. Previously (https://github.com/jax-ml/jax/pull/28867), she got the cross-host transfers working on TPU using the `MakeCrossHostReceiveBuffers` and `CopyToRemoteDevice` APIs. This CL implements these two APIs on GPU.

# Future Work

This CL introduces a very basic, very limited implementation that should be improved in the future. For now, `MakeCrossHostReceiveBuffers` creates a `CliqueId` (think `ncclUniqueId`) and a communicator (think `ncclComm_t`) using this `CliqueId`. The `CliqueId` is sent to the sending process, and the sending process creates the corresponding communicator. The data is then sent using the communicators' `Send` and `Recv` APIs.

This design is suboptimal because it creates a new pair of communicators for every transfer. It doesn't use the communicator caching code path that other collectives use.

I was also a bit unclear on how memory transfers are ordered. For example, if two different buffers need to be transferred between two devices, don't we need to be careful that the senders and receivers agree on the order in which these buffers will be sent?

PiperOrigin-RevId: 771197322