flax
8e51b714 - [JAX] Replace uses of jax.experimental.pjit.with_sharding_constraint with jax.lax.with_sharding_constraint.

Commit
2 years ago
[JAX] Replace uses of jax.experimental.pjit.with_sharding_constraint with jax.lax.with_sharding_constraint. This API has graduated from experimental status; use the non-experimental name. No functional changes intended. PiperOrigin-RevId: 553879678
Author
Committer
Parents
Loading