66037d10 - Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to https://github.com/jax-ml/jax/commit/bcd4048dd5aa1ed8da39f6be88b4f1f3715f77ec
Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to https://github.com/jax-ml/jax/commit/bcd4048dd5aa1ed8da39f6be88b4f1f3715f77ec
PiperOrigin-RevId: 729532213