Drop into `Auto` mode for `.at[...].set(...)` but instead of taking an `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved.
Fixes https://github.com/jax-ml/jax/issues/28111
PiperOrigin-RevId: 749089846