jax
66037d10 - Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to https://github.com/jax-ml/jax/commit/bcd4048dd5aa1ed8da39f6be88b4f1f3715f77ec

Commit
319 days ago
Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to https://github.com/jax-ml/jax/commit/bcd4048dd5aa1ed8da39f6be88b4f1f3715f77ec PiperOrigin-RevId: 729532213
Author
Parents
Loading