flax
7b20ed2d - Allow use of RngBitGenerator TPU hardware PRNG system.

Commit
4 years ago
Allow use of RngBitGenerator TPU hardware PRNG system. The old rng_uniform rng op is a side-effecting nondeterministic rng system that can fail in cases such as grad of scanned-layers for things like dropout. A newer system exploiting JAX custom PRNGS and the RNGBitGenerator XLA Op can now be used instead to avoid these pathologies and to make the hardware RNG system more deterministic in nature. It should be highlighted that this new PRNG system is only deterministic for a given sharding, and does not preserve random bits across different shardings. The new hardware RNG system can be enabled by setting the use_hardware_rng=True kwarg on the top-level train function. PiperOrigin-RevId: 400667551
References
Author
Committer
Parents
Loading