[JAX] Use `xla::ifrt::Client::MakeArraysFromHostBufferShards()` in Array creation when possible
This change makes use of the new
`xla::ifrt::Client::MakeArraysFromHostBufferShards()` API when possible. This
API needs a single call to create a multi-shard IFRT Array (to be wrapped as a
JAX `PyArray`), which provides more optimization opportunities for the runtime
than creating single-device IFRT Arrays and then assembling them. Please note
that the `xla::ifrt::Client::MakeArraysFromHostBufferShards()` implementation in
PjRt-IFRT is not yet optimized, so there are no immediate performance benefits
for McJAX.
As an exception, array creation takes the previous assembly path if any shard
for `BatchedDevicePut` is not a host buffer but already a single-device array,
because `xla::ifrt::Client::MakeArraysFromHostBufferShards()` works only if all
sharded inputs are host buffers.
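
The path selection described above can be illustrated with a minimal Python
sketch. It is not the actual C++ implementation: `HostBuffer`,
`SingleDeviceArray`, and `choose_creation_path` are hypothetical names
introduced only for this illustration, and the real code operates on IFRT
types rather than these stand-ins.

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass
class HostBuffer:
    """Hypothetical stand-in for a shard whose data lives in host memory."""
    data: bytes


@dataclass
class SingleDeviceArray:
    """Hypothetical stand-in for a shard that is already a single-device array."""
    device_id: int


Shard = Union[HostBuffer, SingleDeviceArray]


def choose_creation_path(shards: Sequence[Shard]) -> str:
    """Return the creation path a batch of shards would take in this sketch.

    "batched"  : every shard is a host buffer, so one batched call can
                 create the multi-shard array.
    "assemble" : at least one shard is already a single-device array, so
                 the previous per-shard creation plus assembly path is used.
    """
    if all(isinstance(shard, HostBuffer) for shard in shards):
        return "batched"
    return "assemble"
```

For example, two `HostBuffer` shards select the batched path, while mixing in
one `SingleDeviceArray` shard selects the assembly path.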
With batching possible at the IFRT level, we now skip the `DevicePutResultFn`
step; `DevicePut` (now `DevicePutWithDevice` and `DevicePutWithSharding`)
internally calls per-shard functions (with the GIL released) and returns a
final IFRT Array.
This change includes a code cleanup for
`xla::DevicePutResult::owning_pybuffer`, which was originally intended to hold
a Python object to keep an IFRT Array valid when it is created from
`DevicePut()` implementations, but this role is now entirely covered by the
`on_done_with_host_buffer` function supplied at IFRT Array creation time.
PiperOrigin-RevId: 749989229