jax
e598bc8e - [PJRT PLUGIN] Pass in a blocking key value get and a key value put function pointer instead of a DistributedRuntimeClient pointer when creating a GPU client.

Commit
2 years ago
[PJRT PLUGIN] Pass in a blocking key value get and a key value put function pointer instead of a DistributedRuntimeClient pointer when creating a GPU client.

This is to support the multi-process/multi-node GPU PJRT plugin. A DistributedRuntimeClient pointer should not be passed through the C API boundary. Therefore the framework provides a key value get and a key value put function pointer instead.

This change focuses on changes related to the C++ GPU client. C API related changes will be a follow-up change.

This change includes:
- Use kv_get and kv_put in NCCL id.
- The lead node (node_id 0) uses kv_get and kv_put to generate and put the global topology in se_gpu_pjrt_client. Other nodes use kv_get to get the global topology.
- Use kv_get, kv_put and number of nodes when creating a StreamExecutorGpuClient. kv_get and kv_put can be generated from DistributedRuntimeClient. However, DistributedRuntimeClient and DistributedRuntimeService do not expose the number of nodes. Currently it is obtained from distributed.global_state.
- Modify xla.cc to create kv_get and kv_put from DistributedRuntimeClient.
- Modify xla_bridge to pass in num_nodes.
- Change call sites of GetStreamExecutorGpuClient. Most call sites use a nullptr DistributedRuntimeClient and it is a no-op for them.

PiperOrigin-RevId: 538845061
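
The topology exchange described in the message can be illustrated with a small Python sketch. This is not the XLA implementation: the in-memory store, the key names (`local_topology:<id>`, `global_topology`), the string encoding of a topology, and the helper names are all assumptions made for illustration. The step where every node publishes its own local topology is also an assumption; the message only states that the lead node generates and puts the global topology and that other nodes get it. The point is only that the client depends on two callables, a blocking get and a put, plus the number of nodes, rather than on a DistributedRuntimeClient object.

```python
import threading
from typing import Callable, Dict, List

# Hypothetical signatures for the two callbacks handed to the client.
KVGet = Callable[[str, float], str]  # (key, timeout_seconds) -> value; blocks until set
KVPut = Callable[[str, str], None]   # (key, value) -> None


class InMemoryKVStore:
    """Stand-in for the distributed key-value service (assumption)."""

    def __init__(self) -> None:
        self._data: Dict[str, str] = {}
        self._cond = threading.Condition()

    def put(self, key: str, value: str) -> None:
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()  # wake any blocked getters

    def blocking_get(self, key: str, timeout_s: float) -> str:
        with self._cond:
            found = self._cond.wait_for(lambda: key in self._data, timeout=timeout_s)
            if not found:
                raise TimeoutError(f"key {key!r} not set within {timeout_s}s")
            return self._data[key]


def exchange_topology(node_id: int, num_nodes: int, local_topology: str,
                      kv_get: KVGet, kv_put: KVPut, timeout_s: float = 5.0) -> str:
    """Return the global topology as seen by this node."""
    # Every node publishes its own local topology (assumed step).
    kv_put(f"local_topology:{node_id}", local_topology)
    if node_id == 0:
        # The lead node gathers all local topologies and publishes the result.
        parts = [kv_get(f"local_topology:{i}", timeout_s) for i in range(num_nodes)]
        global_topology = ";".join(parts)
        kv_put("global_topology", global_topology)
        return global_topology
    # Other nodes only read the global topology.
    return kv_get("global_topology", timeout_s)


def run_all_nodes(store: InMemoryKVStore, local_topologies: List[str]) -> List[str]:
    """Simulate each node as a thread sharing one store; results ordered by node id."""
    num_nodes = len(local_topologies)
    results: Dict[int, str] = {}

    def worker(node_id: int) -> None:
        results[node_id] = exchange_topology(
            node_id, num_nodes, local_topologies[node_id],
            store.blocking_get, store.put)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_nodes)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return [results[i] for i in range(num_nodes)]
```

Because `exchange_topology` receives `num_nodes` as a plain argument, the sketch also mirrors why the commit threads the number of nodes through separately: the two callbacks alone carry no information about how many participants to wait for.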
Author
Jieying Luo