jax
695eac8e - [Pallas:MGPU] Fix scratch size for cross-warp reductions.

Commit
21 days ago
[Pallas:MGPU] Fix scratch size for cross-warp reductions. We need `vector_length * 128` (number of lanes) where `vector_length` can be > 2. PiperOrigin-RevId: 856694658
Author
Parents
Loading