Fix PRNG handling in `nn.jit` under `nn.scan`.
* `nn.scan` does an abstract eval before compilation to check for constants that are then traced out. Before this change, the abstract eval increments static RNG counters, which creates a side-effect where RNG counters are not properly updated once inner functions are jitted (i.e., under `nn.jit`).
* In this fix, we cache the impact a first pass through `nn.scan` and `nn.jit` would have on rng counters and "replay" that impact on subsequent passes so that rng state remains unaffected.
* This solution doesn't affect PRNG derivations outside this isolated case and is a placeholder while a more permanent solution, which would affect PRNG derivations, is worked out.
PiperOrigin-RevId: 694463650