jax
cd513f25 - [Pallas:MGPU] Perform a warpgroup barrier before and after every SMEM write

Commit
262 days ago
[Pallas:MGPU] Perform a warpgroup barrier before and after every SMEM write

This allows us to guarantee the single-thread sequential semantics that we want to see in Pallas, even if it sometimes goes a little overboard with the barriers. However, there are situations when both are necessary!

We barrier before we overwrite memory to ensure that all the warps are done reading from it before we do so. Conversely, we barrier after the store to make sure its effects are visible to reads issued from all other warps in the same Pallas thread (i.e. the warpgroup).

I hope this will not lead to significant performance problems, since we generally only write from registers to SMEM once in the whole kernel (in the epilogue), and we usually had to perform a warpgroup barrier there too (as well as the async proxy fence).

PiperOrigin-RevId: 755285818
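The two barriers described in the commit message can be pictured with a small CPU-side analogy. This is only a sketch, not Pallas or Mosaic GPU code: Python threads stand in for the four warps of a warpgroup, a plain list stands in for SMEM, and `threading.Barrier` stands in for the warpgroup barrier. The names `run_warpgroup`, `smem`, `old_seen` and `new_seen` are hypothetical and chosen for illustration.

```python
import threading

NUM_WARPS = 4  # a warpgroup is made up of four warps


def run_warpgroup():
    # Illustrative stand-ins: a list for SMEM, a thread barrier for the
    # warpgroup barrier. None of this is the actual Pallas lowering.
    smem = [10, 11, 12, 13]
    barrier = threading.Barrier(NUM_WARPS)
    old_seen = [None] * NUM_WARPS
    new_seen = [None] * NUM_WARPS

    def warp(w):
        neighbour = (w + 1) % NUM_WARPS

        # Each warp first reads a slot that another warp is about to overwrite.
        old_seen[w] = smem[neighbour]

        # Barrier BEFORE the write: no warp overwrites SMEM until every
        # warp has finished reading the old contents.
        barrier.wait()

        smem[w] = 100 + w

        # Barrier AFTER the write: no warp reads SMEM again until every
        # warp's store has completed and is visible.
        barrier.wait()

        new_seen[w] = smem[neighbour]

    threads = [threading.Thread(target=warp, args=(w,)) for w in range(NUM_WARPS)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return old_seen, new_seen


if __name__ == "__main__":
    old_seen, new_seen = run_warpgroup()
    print(old_seen)  # every warp saw its neighbour's old value
    print(new_seen)  # every warp saw its neighbour's new value
```

Dropping the first `barrier.wait()` would let a fast warp overwrite a slot before its neighbour had read the old value. Dropping the second would let a warp read a slot before its neighbour's store had landed. With both in place, the warps behave as if a single thread had performed the read, the write, and the second read in order.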