jax
896b6e83 - [Mosaic GPU] Add support for reductions across a subset of warps

Commit
263 days ago
[Mosaic GPU] Add support for reductions across a subset of warps This is a new capability in the warp reduction code that used to raise a `NotImplementedError`. I've also added a check that none of the warp dims are replicated (we were missing it before), since that will require a similar but slightly different handling. The reason our test didn't catch this before is because we never sampled multiple warp dims, and `max` is not the best function because including the replicated parts doesn't change the output. This is why I also changed the test to now use integer add, which is associative (so we can continue to test without worrying about error accumulation) and would catch this problem. PiperOrigin-RevId: 780562520
Author
Parents
Loading