jax
262aab74 - canonicalize closed over values if **atleast** 1 mesh axis is `Manual` and **all other mesh axes** are `Manual` or `Auto`. This would make the canonicalization work properly with shmap partial-auto.

Commit
333 days ago
canonicalize closed over values if **atleast** 1 mesh axis is `Manual` and **all other mesh axes** are `Manual` or `Auto`. This would make the canonicalization work properly with shmap partial-auto. If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map. PiperOrigin-RevId: 728956512
Author
Parents
Loading