jax
b97ad423 - Consider this example:

Commit
209 days ago
Consider this example: ``` mesh = [('x', 2, Explicit), ('y', 2, 'Auto')] def f(x: f32[8@x, 2]): y: f32[8@x, 2] = wsc(x, P('y')) return y f(arr: f32[8@x, 2]) ``` The with_sharding_constraint should lower to `f32[8@(x,y), 2]` instead of `f32[8@y, 2]` as it does right now. This change fixes that to lower to the former sharding i.e. `f32[8@(x,y), 2]` **Note: this only affects lowering and not tracing.** PiperOrigin-RevId: 785474636
Author
Parents
Loading