flax
0f631a27 - Fix PRNG handling in `nn.jit` under `nn.scan`.

Commit
1 year ago
Fix PRNG handling in `nn.jit` under `nn.scan`. * `nn.scan` does an abstract eval before compilation to check for constants that are then traced out. Before this change, the abstract eval increments static RNG counters, which creates a side-effect where RNG counters are not properly updated once inner functions are jitted (i.e., under `nn.jit`). * In this fix, we cache the impact a first pass through `nn.scan` and `nn.jit` would have on rng counters and "replay" that impact on subsequent passes so that rng state remains unaffected. * This solution doesn't affect PRNG derivations outside this isolated case and is a placeholder while a more permanent solution, which would affect PRNG derivations, is worked out. PiperOrigin-RevId: 694463650
Author
Committer
Parents
Loading