jax
48140bcb - Add basic support of `unreduced` to sharding-in-types! We cannot lower it right now, but it atleast shows up in types.

Commit
264 days ago
Add basic support of `unreduced` to sharding-in-types! We cannot lower it right now, but it atleast shows up in types. The API to specify unreduced is via `PartitionSpec`. For example: `PartitionSpec('x', 'y', None, unreduced='z')` or `PartitionSpec('x', unreduced=('y', 'z'))`. In types/jaxpr, unreduced will show up as: `f32[8@x,2]{U:y}` But we support unreduced only in dot_general and nary ops (add, mul, etc) as of this change: (the support will be expanded in following changes) * **dot general** only allows unreduced when contracting dims are sharded. And the unreduced axes specified by the user needs to match the sharding of the contracting dims. In all other cases, an error is raised. An example of how unreduced can be specified: `jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced='y'))` * **nary ops** can propagate unreduced (add, mul, etc). If all ops aren't unreduced across the same mesh axes, an error is raised. PiperOrigin-RevId: 756063074
References
Author
Parents
Loading