jax
f73d8c93 - Enter into auto mode for `.at(...).add(...)` and set the `out_sharding` (on auto_axes) to be the same as the input sharding. Fixes https://github.com/jax-ml/jax/issues/29654

Commit
209 days ago
Enter into auto mode for `.at(...).add(...)` and set the `out_sharding` (on auto_axes) to be the same as the input sharding. Fixes https://github.com/jax-ml/jax/issues/29654 PiperOrigin-RevId: 776744299
References
Author
Parents
Loading