Add the infer_params cache back that was removed during the lu.cache -> weakref_lru_cache move.
**Why do we need this?**
When there are nested jits, jit generally drops out of cpp cache because it does not know about Tracers and enters python land. Now, there is a lot of slow python code in _trace_for_jit (like debug_info calculation, various checks, etc) and we run it even if we get a tracing cache hit!
The infer_params cache is designed to basically skip all the slow python stuff if the `fun, jit_info, arg_signature (dynamic_arg_treedefs, static_args, static_argnames, dynamic_arg_names), avals, ctx_mesh` are the same as before.
For example:
```
@jax.jit
def f(x):
x = x + 1
x = x + 1
x = x + 1
return x
f(jnp.arange(8))
```
In this example, `+` is jitted (so we have nested jits) and we can skip all the slow python work for 2 out of the 3 `+`.
PiperOrigin-RevId: 863387193