Migrate flax from using old-style PRNG keys to new-style typed PRNG keys
Functionally, this involves changing uses of jax.random.PRNGKey to jax.random.key. For details on this change and the motivation behind it, see the draft JEP at https://github.com/google/jax/pull/17297, and please feel free to offer comments and feedback!
PiperOrigin-RevId: 565475405