jax
c9de52fa - [Pallas:MGPU] Add support for slicing WGMMA accumulator under WG semantics.

Commit
32 days ago
[Pallas:MGPU] Add support for slicing WGMMA accumulator under WG semantics. We lower WGMMA accumulator slicing to: ``` acc_slice = vector.extract_strided_slice(acc, ...) new_acc = mgpu.dialect.wgmma(acc_slice, ...) acc = vector.insert_strided_slice(new_acc, acc, ...) ``` PiperOrigin-RevId: 903245417
Author
Parents
Loading