jax
f346fd09 - [JAX] Use `xla::ifrt::Client::MakeArraysFromHostBufferShards()` in Array creation when possible

Commit
251 days ago
[JAX] Use `xla::ifrt::Client::MakeArraysFromHostBufferShards()` in Array creation when possible This changes makes use of the new `xla::ifrt::Client::MakeArraysFromHostBufferShards()` API when possible. This API needs a single call to create a multi-shard IFRT Array (to be wrapped as a JAX `PyArray`), which provides more optimization opportunities for the runtime than creating single-device IFRT Arrays and then assembling them. Please note that `xla::ifrt::Client::MakeArraysFromHostBufferShards()` implementation in PjRt-IFRT is not yet optimized, so there is no immediate performance benefits for McJAX. As an exception, it takes the previous path of array assembly if any shard for `BatchedDevicePut` is not a host buffer, but already a single-device array, because `xla::ifrt::Client::MakeArraysFromHostBufferShards()` works only if all the sharded input to be host buffers. With batching possible at IFRT level, we now skip `DevicePutResultFn` step; `DevicePut` (now `DevicePutWithDevice` and `DevicePutWithSharding`) internally calls per-shard functions (with GIL released) and returns a final IFRT Array. This change includes a code cleanup for `xla::DevicePutResult::owning_pybuffer`, which was originally intended to hold a Python object to keep an IFRT Array valid when it is created from `DevicePut()` implementations, but this role has been entirely covered by `on_done_with_host_buffer` function supplied at IFRT Array creation time. PiperOrigin-RevId: 749989229
Author
Parents
Loading