jax
a0f8a4aa - [Pallas:MGPU] Add an all-reduce mode to the reduce-scatter kernel

Commit
149 days ago
[Pallas:MGPU] Add an all-reduce mode to the reduce-scatter kernel The same strategy we used before (ld_reduce) works, we just need to issue more instructions. PiperOrigin-RevId: 819267594
Author
Parents
Loading