flax
Use jax.api.device_put_sharded() in place of private JAX APIs.
#466
Merged

Loading