[pmap] Cache jit wrappers in pmap's shard_map path to reduce dispatch overhead

When jax_pmap_shmap_merge=True, pmap is implemented on top of shard_map
via _cached_shard_map. However, infer_params() called api.jit() on every
pmap invocation to build a new jit wrapper, adding ~128us of Python
overhead per call (80% of the 159us total dispatch overhead).

Profiling results on a simple workload (50 iterations of tanh(x@x.T)):
  Before: pmap exec 0.44ms/iter vs jit+vmap 0.25ms/iter (1.69x slower)
  After:  pmap exec 0.27ms/iter vs jit+vmap 0.23ms/iter (1.18x slower)

The fix pre-creates two jit wrappers in _cached_shard_map:
  1. jitted_f: without explicit shardings (for tracing contexts)
  2. jitted_f_with_shardings: with in/out shardings (for execution)
infer_params now selects the cached variant based on trace_state_clean
instead of constructing a new api.jit() wrapper on each call.

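The caching pattern can be sketched in plain Python as follows. This is an
illustration only, not the code in this CL: make_jit_wrapper is a
hypothetical stand-in for api.jit(), BUILD_COUNT exists only to make the
number of wrapper constructions visible, and the mapping "clean trace state
selects the sharded variant" is an assumption inferred from the description
above.

```python
import functools

# Illustration-only counter: how many wrappers have been constructed.
BUILD_COUNT = {"n": 0}


def make_jit_wrapper(fun, shardings=None):
    """Hypothetical stand-in for api.jit(); not the real JAX API."""
    BUILD_COUNT["n"] += 1

    def wrapped(*args):
        return fun(*args)

    wrapped.shardings = shardings
    return wrapped


@functools.lru_cache(maxsize=None)
def cached_shard_map(fun, in_shardings, out_shardings):
    """Build both wrappers once per (fun, in_shardings, out_shardings) key."""
    jitted_f = make_jit_wrapper(fun)
    jitted_f_with_shardings = make_jit_wrapper(
        fun, shardings=(in_shardings, out_shardings)
    )
    return jitted_f, jitted_f_with_shardings


def infer_params(fun, in_shardings, out_shardings, trace_state_clean):
    """Pick a pre-built wrapper instead of constructing a new one per call."""
    jitted_f, jitted_f_with_shardings = cached_shard_map(
        fun, in_shardings, out_shardings
    )
    # Assumption: a clean trace state means top-level execution, so the
    # variant with explicit shardings is used; otherwise the plain one.
    return jitted_f_with_shardings if trace_state_clean else jitted_f


def double(x):
    return 2 * x


# 50 calls, mirroring the iteration count of the profiled workload.
for _ in range(50):
    call = infer_params(double, "in_spec", "out_spec", trace_state_clean=True)

print(BUILD_COUNT["n"])  # 2: both wrappers were built once, on the first call
```

The cache key must be hashable, which is why the sketch passes plain strings
as placeholder shardings. The wrapper-construction cost is paid once per key
rather than once per call.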
Breakdown of per-call dispatch overhead (before fix):
  arg_proc (tree_flatten, etc.):      16us (10%)
  cache lookup (_cached_shard_map):   15us  (9%)
  api.jit() creation:                128us (80%)  <-- eliminated by this CL
  total infer_params:                159us

PiperOrigin-RevId: 867697792