jax
bcead8e8 - Make `.sharding` attribute of ShapeDtypeStruct dynamic.

Commit
67 days ago
Make `.sharding` attribute of ShapeDtypeStruct dynamic. If a `PartitionSpec` is passed to SDS constructor, the mesh is bound at call time (of accessing `.sharding`) instead of at construction time (calling SDS constructor). Basically this allows you to write code like this: ``` # No error with this change! Before this change, it would have error saying there is # no context mesh, so you can pass pspec to SDS. sds = jax.ShapeDtypeStruct(shape, dtype, sharding=P('x')) with jax.set_mesh(jax.make_mesh((2,), ('x',))): out = sds.sharding assert out.sharding == NamedSharding(mesh, P('x')) with jax.set_mesh(jax.make_mesh((2,), ('y',))) # This will error since we will create # NamedSharding(mesh_with_y_axis_name, P('x')) _ = sds.sharding ``` This dynamic pattern in used in some of the APIs like `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` where the mesh is bound at call time instead of creation time. PiperOrigin-RevId: 818695338
Author
Parents
Loading