jax
a8592fe0 - [Pallas:MGPU] Add initial support for TMEM under WG semantic.

Commit
48 days ago
[Pallas:MGPU] Add initial support for TMEM under WG semantic. We introduce a new op to the MGPU dialect: `slice_tmem`. This op extracts a new contiguous TMEM reference from the original one at a given offset. The resulting reference can have a different layout and element type than the original one. This is necessary to share most of the code between Pallas LANE and WG lowering. For LANE lowering, Pallas issues a single TMEM allocation and slice it into smaller ones. We also add Pallas lowering for `async_load_tmem`, `async_store_tmem` and `commit_tmem` in the same change to be able to run basic tests. PiperOrigin-RevId: 827506566
Author
Parents
Loading