Merge all_gather with all_gather_reduced, psum_scatter with unreduced_psum_scatter and psum with unreduced_psum.
Here are the changes:
* all_gather signature gets a `to` argument. `all_gather(x, axis_name, tiled=True, to=...)`. The allowed values are `varying` and `reduced`. `to` defaults to `varying` to preserve the current behavior but you can get `AGR` by specifying `to='reduced'`
* `psum_scatter` will infer the input state from the type. If the input is unreduced over the axis_name, then we will dispatch to `unreduced_psum_scatter_p`. If the input is varying, it will dispatch to `reduce_scatter_p`
* `psum` will infer the input state from the type. If the input is unreduced over the axis_name, then we will dispatch to `unreduced_psum_p`. If the input is varying, it will dispatch to `psum_invariant_p`
PiperOrigin-RevId: 839351465