Consider this example:
```
mesh = [('x', 2, Explicit), ('y', 2, 'Auto')]
def f(x: f32[8@x, 2]):
y: f32[8@x, 2] = wsc(x, P('y'))
return y
f(arr: f32[8@x, 2])
```
The with_sharding_constraint should lower to `f32[8@(x,y), 2]` instead of `f32[8@y, 2]` as it does right now. This change fixes that to lower to the former sharding i.e. `f32[8@(x,y), 2]`
**Note: this only affects lowering and not tracing.**
PiperOrigin-RevId: 785474636