jax
b8ecb137 - [pmap] Inline `multihost_utils` `impl` functions for host_local <> global conversions, `_shared_code_pmap`, and `_prepare_pmap`.

Commit
21 days ago
[pmap] Inline `multihost_utils` `impl` functions for host_local <> global conversions, `_shared_code_pmap`, and `_prepare_pmap`. When `core.trace_state_clean()`, call the `mhu` implementations directly to avoid binding overheads. Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 859185228
Author
Parents
Loading