flax
6444f97e - Fix RNG splitting in core lift functions for KeyArrays.

Commit
4 years ago
Fix RNG splitting in core lift functions for KeyArrays. JAX KeyArrays are themselves pytree containers, so for the moment pending further improvements we need to map the jax.random.split/fold_in/etc functions with a tree_map marking KeyArrays as leaves. PiperOrigin-RevId: 401326381
Author
Committer
Parents
Loading