jax
f021cfa9 - [Pallas:MGPU] Only infer vec_size=1 in the reduce scatter kernel for integer types

Commit
127 days ago
[Pallas:MGPU] Only infer vec_size=1 in the reduce scatter kernel for integer types Vectorized multimem reductions are not supported for integers. PiperOrigin-RevId: 819284336
Author
Parents
Loading