jax
81ec7435 - [pmap] Fix PRNGKeyArray type in multihost array conversion.

Commit
68 days ago
[pmap] Fix PRNGKeyArray type in multihost array conversion. In `global_array_to_host_local_array` (and symmetrically in `host_local_array_to_global_array`), when an output is a `PRNGKeyArray`, the code unwraps it to `arr = arr._base_array` (an `ArrayImpl`), but the `typ` variable still holds the original `PRNGKeyArray` type from before unwrapping. This causes the subsequent `if typ == array.ArrayImpl` branch to be skipped, routing the now-unwrapped global `ArrayImpl` into the `else` branch that calls `batched_device_put`. In multihost settings, this global array spans all devices (e.g., 16 shards), and `batched_device_put` expects single-shard arrays, producing: INVALID_ARGUMENT: device_put expects an array with exactly one shard, got an array with with 16 shards. The fix re-evaluates `typ = type(arr)` after the type normalization block so that unwrapped `ArrayImpl` values are correctly handled via the fast `_rewrap_with_aval_and_sharding` path. PiperOrigin-RevId: 869390570
Author
Parents
Loading