jax
280f6b4e - [pmap] Cache jit wrappers in pmap's shard_map path to reduce dispatch overhead

Commit
2 days ago
[pmap] Cache jit wrappers in pmap's shard_map path to reduce dispatch overhead

When jax_pmap_shmap_merge=True, pmap internally uses shard_map via _cached_shard_map. However, infer_params() was calling api.jit() on every single pmap invocation to create a new jit wrapper, adding ~128us of Python overhead per call (80% of the total 159us dispatch overhead).

Profiling results on a simple workload (50 iterations of tanh(x@x.T)):

  Before: pmap exec 0.44ms/iter vs jit+vmap 0.25ms/iter (1.69x slower)
  After:  pmap exec 0.27ms/iter vs jit+vmap 0.23ms/iter (1.18x slower)

The fix pre-creates two jit wrappers in _cached_shard_map:

  1. jitted_f: without explicit shardings (for tracing context)
  2. jitted_f_with_shardings: with in/out shardings (for execution)

infer_params now selects the cached variant based on trace_state_clean instead of constructing a new api.jit() wrapper each call.

Breakdown of per-call dispatch overhead (before fix):

  arg_proc (tree_flatten, etc.):      16us (10%)
  cache lookup (_cached_shard_map):   15us (9%)
  api.jit() creation:                128us (80%)  <-- eliminated by this CL
  total infer_params:                159us

PiperOrigin-RevId: 867697792