Fix RNG splitting in core lift functions for KeyArrays.
JAX KeyArrays are themselves pytree containers, so for the moment
pending further improvements we need to map the jax.random.split/fold_in/etc
functions with a tree_map marking KeyArrays as leaves.
PiperOrigin-RevId: 401326381