jax
cbd0844e - [Mosaic GPU] Fold non-gather indices into the gather TMA column index

Commit
126 days ago
[Mosaic GPU] Fold non-gather indices into the gather TMA column index The offset is shared by all gather indices so this method reduces the overall amount of arithmetic instructions we emit. For as long as we don't add support for inputs that are more than 2D this essentially undoes any of the ops inserted by the tiling transforms (and so the only downside of tiling is that we might use more TMA instructions, but not more arithmetic). This implementation is simpler too. PiperOrigin-RevId: 796798377
Author
Parents
Loading