jax
854b2c85 - Drop into `Auto` mode for `.at[...].set(...)` but instead of taking an `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved.

Commit
324 days ago
Drop into `Auto` mode for `.at[...].set(...)` but instead of taking an `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved. Fixes https://github.com/jax-ml/jax/issues/28111 PiperOrigin-RevId: 749089846
Author
Parents
Loading