jax
5ab714bd - Preserve the sharding provided in `out_sharding` argument of `auto_axes`/or any other API when the computation has a single device in explicit mode

Loading