jax
8f21ab30 - [Pallas:MGPU] Use a much better matmul kernel in the collective matmul

Commit · 122 days ago
[Pallas:MGPU] Use a much better matmul kernel in the collective matmul

Turns out it wasn't the collective part that was holding us back, but the matmul part. Now that we have a really good matmul kernel, we can simply plug it into the collective loop and add a tiny part that does the sends.

In my simple benchmarking setup it already seems to consistently beat the NCCL+cuBLAS baseline within a single host.

PiperOrigin-RevId: 810794450
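To illustrate the structure the message describes (a collective loop that runs a local matmul each step plus "a tiny part that does the sends"), below is a minimal, hedged sketch in plain JAX of a ring collective matmul. It is not the Pallas:MGPU kernel from this commit: it uses `shard_map`, `jax.lax.ppermute`, and a plain `@` matmul instead of Mosaic GPU primitives, and the names `ring_collective_matmul`, `axis_name`, and `axis_size` are illustrative assumptions.

```python
# Hedged sketch (plain JAX, not the actual Pallas:MGPU kernel from this commit):
# a ring collective matmul that interleaves a per-shard matmul with a send of
# the LHS shard to the next device. Function and argument names are illustrative.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


def ring_collective_matmul(lhs_shard, rhs_local, *, axis_name, axis_size):
    """All-gather-style collective matmul: each device starts with one row-shard
    of LHS; every step multiplies the shard it currently holds by its local RHS,
    then sends ("rotates") that shard to the next device on the ring."""
    idx = jax.lax.axis_index(axis_name)
    shard_rows = lhs_shard.shape[0]
    out = jnp.zeros((shard_rows * axis_size, rhs_local.shape[1]), lhs_shard.dtype)
    perm = [(d, (d + 1) % axis_size) for d in range(axis_size)]

    def step(i, carry):
        block, out = carry
        # The shard held at step i originally lived on device (idx - i) mod n,
        # so its partial result belongs at that row offset of the full output.
        src = (idx - i) % axis_size
        out = jax.lax.dynamic_update_slice(
            out, block @ rhs_local, (src * shard_rows, 0))
        if axis_size > 1:
            # The "send" part of the collective loop: rotate shards along the ring.
            block = jax.lax.ppermute(block, axis_name, perm)
        return block, out

    _, out = jax.lax.fori_loop(0, axis_size, step, (lhs_shard, out))
    return out


if __name__ == "__main__":
    n = jax.device_count()
    mesh = Mesh(np.array(jax.devices()), ("x",))
    lhs = jnp.arange(8 * n * 4, dtype=jnp.float32).reshape(8 * n, 4)
    rhs = jnp.ones((4, 6), dtype=jnp.float32)
    f = shard_map(
        partial(ring_collective_matmul, axis_name="x", axis_size=n),
        mesh=mesh, in_specs=(P("x", None), P()), out_specs=P(),
        check_rep=False)  # skip the replication check for this sketch
    np.testing.assert_allclose(jax.jit(f)(lhs, rhs), lhs @ rhs)
```

The point of the sketch is only the loop shape: the local matmul dominates each step, and the send of the shard to the neighbor is the small added piece, which is why swapping in a better matmul kernel is what improved the collective matmul.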