jax
a05263f5 - Avoid tuple psum in pmean (#3479)

Commit
5 years ago
Avoid tuple psum in pmean (#3479) We have an optimization to avoid doing flops (or rather, communication) for `psum(1)` (instead we look up the axis size at trace time). But although `pmean(x)`, meaning `psum(x) / psum(1)`, is likely the most common user of `psum(1)`, it doesn't actually trigger this optimization right now because it's implemented as `psum((x, 1))` and `bind` lifts the 1 into the same `JaxprTrace` as `x` rather than letting the psum impl rule see it. The major reason for using tuple psum—providing a fixed order to avoid multihost GPU deadlocks—doesn't apply here because we don't expect the `psum(1)` to lower to an actual XLA AllReduce.
Author
Parents
Loading