jax
47bc2f55 - convert NumPy RNG key data to uncommitted default-device-backed `jax.Array` data

Commit
327 days ago
convert NumPy RNG key data to uncommitted default-device-backed `jax.Array` data Generally, we want to maintain that key data backing a `PRNGKeyArray` is a `jax.Array`. This change converts NumPy arrays on construction. Co-authored-by: Yash Katariya <yashkatariya@google.com> PiperOrigin-RevId: 748077900
Author
Parents
Loading