Allow reduce_sum to accept unreduced inputs. reduce_sum on a sharded input consists of two steps: (1) local reduction and (2) all-reduce across devices.
A simple example:
```
inp: f32[4]{U:x}
out: f32[]{U:x} = reduce_sum(inp, axes=[0])
```
We simply propagate the unreduced axes from input to output. This means we only perform the local sum and skip the all-reduce (AR).
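The semantics can be sketched with plain numpy, modeling an unreduced array as one partial term per device whose true value is the elementwise sum across devices (this is a simulation of the semantics, not JAX's implementation; the shard values are made up):

```python
import numpy as np

# Hypothetical mesh axis 'x' with 2 devices. inp: f32[4]{U:x} is
# represented as one partial per device; the logical value is the
# elementwise sum of the partials.
inp_shards = [np.array([1., 2., 3., 4.], np.float32),
              np.array([5., 6., 7., 8.], np.float32)]

# Step (1): local reduction on each device.
local = [s.sum(axis=0) for s in inp_shards]  # per-device partial sums

# Step (2) is skipped: the output f32[]{U:x} stays unreduced over 'x',
# so each device keeps its partial. Materializing the replicated value
# later is exactly the deferred all-reduce:
full = sum(local)
```

Summing the per-device partials (`10.0 + 26.0 = 36.0`) matches reducing the logical array directly, which is why deferring the AR is safe.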
More examples:
```
inp = f32[8@x, 4]{U:y}
# case 1
jax.lax.reduce_sum(inp, axes=[0]) -> f32[4]{U:y}
# case 2
jax.lax.reduce_sum(inp, axes=[0], out_sharding=P(unreduced={'x', 'y'})) -> f32[4]{U:(x,y)}
# case 3
jax.lax.reduce_sum(inp, axes=[0], out_sharding=P(unreduced={'x'})) -> f32[4]{U:x}
```
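The three cases can likewise be simulated in numpy on a hypothetical 2x2 ('x', 'y') mesh, where `shards[i][j]` is the partial held by device (x=i, y=j); which all-reduces run is determined by which axes remain unreduced in the output (again a semantic sketch with made-up data, not the actual lowering):

```python
import numpy as np

rng = np.random.default_rng(0)
# inp: f32[8@x, 4]{U:y} — rows sharded over 'x' (4 rows per x-device),
# values unreduced over 'y'.
shards = [[rng.standard_normal((4, 4)).astype(np.float32)
           for _ in range(2)] for _ in range(2)]

# Logical value: concatenate over 'x' (sharded axis), sum over 'y'
# (unreduced axis); then the fully reduced reference result.
logical = np.concatenate([shards[i][0] + shards[i][1] for i in range(2)])
expected = logical.sum(axis=0)  # f32[4], no unreduced axes

# Step (1) everywhere: local reduction over axis 0 on each device.
local = [[s.sum(axis=0) for s in row] for row in shards]

# case 1: AR over 'x' only -> f32[4]{U:y} (one partial per y-device).
case1 = [local[0][j] + local[1][j] for j in range(2)]
# case 2: no AR at all -> f32[4]{U:(x,y)} (all four partials kept).
case2 = local
# case 3: AR over 'y' only -> f32[4]{U:x} (one partial per x-device).
case3 = [local[i][0] + local[i][1] for i in range(2)]
```

In every case, summing the remaining partials recovers `expected`, so the choice of `out_sharding` only controls which all-reduces are deferred, not the final value.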
PiperOrigin-RevId: 866217452