jax
4452dc5b - [Mosaic GPU] Add support for block-scaled f4e2m1fn MMAs

Commit
258 days ago
[Mosaic GPU] Add support for block-scaled f4e2m1fn MMAs Turns out that the `f8f6f4` MMA variant always takes in K=32, no matter what operand bitwidth is used. To make matters worse, it requires each group of 32 elements to be padded in SMEM to 32 bytes, which means that we'd be wasting half the space for f4. That's why it makes more sense to focus on the block-scaled `mxf4` MMA variant, which properly supports K=64 and doesn't require any padding. PiperOrigin-RevId: 780163721
Author
Parents
Loading