Allow use of RngBitGenerator TPU hardware PRNG system.
The old rng_uniform op is a side-effecting, nondeterministic RNG system
that can fail in cases such as taking the grad of scanned layers that use dropout. A newer
system built on JAX custom PRNGs and the RngBitGenerator XLA op can now be used instead
to avoid these pathologies and to make the hardware RNG system more deterministic.
Note that this new PRNG system is deterministic only for a given sharding;
it does not preserve random bits across different shardings.
The new hardware RNG system can be enabled by setting the use_hardware_rng=True kwarg on the
top-level train function.
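
A minimal sketch of the failure case described above (grad of scanned layers
with dropout), written against public JAX only. It assumes JAX's "rbg" PRNG
implementation, which is backed by the RngBitGenerator XLA op; whether
use_hardware_rng=True selects exactly this implementation is an assumption,
and the train-function plumbing is not shown. The names dropout,
scanned_layers and loss are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate):
    # Keep each element with probability (1 - rate) and rescale the survivors.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def scanned_layers(x, key, num_layers=3, rate=0.1):
    # Hypothetical stack of identical layers run under lax.scan.
    # Each step derives its own key by folding the layer index into the base key.
    def body(h, i):
        layer_key = jax.random.fold_in(key, i)
        h = dropout(layer_key, jnp.tanh(h), rate)
        return h, None

    out, _ = jax.lax.scan(body, x, jnp.arange(num_layers))
    return out


def loss(x, key):
    return jnp.sum(scanned_layers(x, key) ** 2)


# Assumption: a typed key using the RngBitGenerator-backed "rbg" implementation.
key = jax.random.key(0, impl="rbg")
x = jnp.ones((4, 8))

# Differentiate with respect to x only; the key is a plain, explicit input.
g1 = jax.grad(loss)(x, key)
g2 = jax.grad(loss)(x, key)
```

Because the key is an explicit input rather than hidden state, the two
gradient evaluations above use the same dropout masks when run on the same
device with the same sharding. This sketch says nothing about results under
a different sharding.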
PiperOrigin-RevId: 400667551