f73d8c93 - Enter into auto mode for `.at(...).add(...)` and set the `out_sharding` (on auto_axes) to be the same as the input sharding. Fixes https://github.com/jax-ml/jax/issues/29654
Enter into auto mode for `.at(...).add(...)` and set the `out_sharding` (on auto_axes) to be the same as the input sharding. Fixes https://github.com/jax-ml/jax/issues/29654
PiperOrigin-RevId: 776744299