jax
0f262f22 - Allow reduce_sum to accept unreduced inputs. reduce_sum on a sharded input consists of 2 steps; (1) local reduction and (2) all-reduce across devices.

Commit
25 days ago
Allow reduce_sum to accept unreduced inputs. reduce_sum on a sharded input consists of 2 steps; (1) local reduction and (2) all-reduce across devices. A simple example: ``` inp: f32[4]{U:x} out: f32[]{U:x} = reduce_sum(inp, axes=[0]) ``` We just propagate the unreduced from input to output. This means that we'll just do the local sum and skip the AR. More examples: ``` inp = f32[8@x, 4]{U:y} # case 1 jax.lax.reduce_sum(inp, axes=[0]) -> f32[4]{U:y} # case 2 jax.lax.reduce_sum(inp, axes=[0], out_sharding=P(unreduced={'x', 'y'}) -> f32[4]{U:(x,y)} # case 3 jax.lax.reduce_sum(inp, axes=[0], out_sharding=P(unreduced={'x'}) -> f32[4]{U:x} ``` PiperOrigin-RevId: 866217452
Author
Parents
Loading