jax
cccc34dc - Raise an error if the type passed to `axis_types` argument of `Mesh` and `AbstractMesh` is not `jax.sharding.AxisType`.

Commit
328 days ago
Raise an error if the type passed to `axis_types` argument of `Mesh` and `AbstractMesh` is not `jax.sharding.AxisType`. PiperOrigin-RevId: 744602037
Author
Parents
Loading