jax
139f212f - [Mosaic GPU] Fix handling of sliced/transposed types in `transformed_smem_ref_type`.

Commit
91 days ago
[Mosaic GPU] Fix handling of sliced/transposed types in `transformed_smem_ref_type`. Previously, we would fail to produce correct strides and offsets in cases involving slicing, leading to verifier errors. Concretely, slicing does not affect strides---but we would recompute the strides solely based on the sliced row-major logical shape of the ref. Contrarily to lane lowering, where tiling transforms are always applied prior to slicing or transposition when lowering, warpgroup lowering works with sliced or transposed refs, and must apply tiling transforms on them a posteriori. Our function is now able to tile arbitrary memref types. PiperOrigin-RevId: 863479995
Author
Parents
Loading