jax
b08a1052 - Add a test for `out: f32[8@x] = add/mul(x: f32[8@x], y: f32[8]{R:x})` in Manual and Explicit mode. Make sure the backward pass is also as expected.

Commit
179 days ago
Add a test for `out: f32[8@x] = add/mul(x: f32[8@x], y: f32[8]{R:x})` in Manual and Explicit mode. Make sure the backward pass is also as expected. There is a difference in how this works in both modes: * Manual (shmap) mode: `shard_map` adds a `reduced -> varying` cast automatically so the `add/mul` ends up being `add/mul(x: f32[8}{V:x}, y_: f32[8]{V:x})` * Explicit mode: While shmap does this automatically, `Explicit` mode errors out and requires users to insert a reshard to go from `f32[8]{R:x} -> f32[8@x]`. We can do this automatically too just like shmap but that's left for a future change. Note that for `replicated`, this reshard is built into `mul/add`. Also, improve error message for gather sharding ambiguity. PiperOrigin-RevId: 836333704
Author
Parents
Loading