jax
250e2ee7 - Use the mesh of `out_aval` when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output is correct.

Commit
314 days ago
Use the mesh of `out_aval` when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output is correct. Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering, otherwise things go horribly wrong. This is actually a XLA bug but we can workaround it in JAX for now. PiperOrigin-RevId: 729307115
Author
Parents
Loading