jax
3b6fd04e - Propagate input sharding to output in full_like if it's fully replicated. Then even if input/output shapes don't match, it should be fine. Fixes https://github.com/jax-ml/jax/issues/35273

Commit
66 days ago
Propagate input sharding to output in full_like if it's fully replicated. Then even if input/output shapes don't match, it should be fine. Fixes https://github.com/jax-ml/jax/issues/35273 PiperOrigin-RevId: 873111177
Author
Parents
Loading