jax
b582e359 - Add the infer_params cache back that was removed during the lu.cache -> weakref_lru_cache move.

Commit
93 days ago
Add the infer_params cache back that was removed during the lu.cache -> weakref_lru_cache move. **Why do we need this?** When there are nested jits, jit generally drops out of cpp cache because it does not know about Tracers and enters python land. Now, there is a lot of slow python code in _trace_for_jit (like debug_info calculation, various checks, etc) and we run it even if we get a tracing cache hit! The infer_params cache is designed to basically skip all the slow python stuff if the `fun, jit_info, arg_signature (dynamic_arg_treedefs, static_args, static_argnames, dynamic_arg_names), avals, ctx_mesh` are the same as before. For example: ``` @jax.jit def f(x): x = x + 1 x = x + 1 x = x + 1 return x f(jnp.arange(8)) ``` In this example, `+` is jitted (so we have nested jits) and we can skip all the slow python work for 2 out of the 3 `+`. PiperOrigin-RevId: 863387193
Author
Parents
Loading