jax
11fe9d20 - Allow unreduced/reduced to work when check_vma=False.

Commit
4 days ago
Allow unreduced/reduced to work when check_vma=False. **What does this mean?** When check_vma=False, shard_map's type system won't track unreduced, reduced, invarying and varying anymore. For example: The only way from unreduced -> varying when check_vma=True is via a reduce_scatter but when check_vma=False, this is just a bitcast. Concretely: Assume you have 2 devices: **With check_vma=True:** ``` unreduced --RS--> varying D0 D1 [0 1 | 2 3] ---reduce-scatter---> [2 | 4] ``` Shape changes: * Global shape: `(2,) -> (2,)` * Device local shape: `(2,) -> (1,)` **With check_vma=False:** ``` unreduced ---bitcast---> varying D0 D1 [0 1 | 2 3] ---> [0 1 | 2 3] ``` Shape changes: * Global shape: `(2,) -> (4,)` * Device local shape: `(2,) -> (2,)` The bitcast is similar for replicated <-> unreduced as sharded <-> unreduced. PiperOrigin-RevId: 862901355
Author
Parents
Loading