flax
c2ba4a38 - Use jax.api.device_put_sharded() in place of private JAX APIs.

Commit
5 years ago
Use jax.api.device_put_sharded() in place of private JAX APIs. PiperOrigin-RevId: 332141521
Author
Jake VanderPlas
Committer
Parents
Loading