jax
ac1a53d8 - Optimize `_create_copy_plan` by iterating over only the shards that are needed for materialization

Commit
1 year ago
Optimize `_create_copy_plan` by iterating over only the shards that are needed for materialization For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX. The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`. PiperOrigin-RevId: 624969222
Author
Committer
Parents
Loading