jax
4cda3a9d - [Pallas/Mosaic GPU] Implement logic that allows commuting `ReshapeTransform`s with `UntilingTransform`s.

Commit
31 days ago
[Pallas/Mosaic GPU] Implement logic that allows commuting `ReshapeTransform`s with `UntilingTransform`s. For now, we only support commuting reshapes that act only as folds (and not unfolds). It is illegal to fold untiled dimensions with tiled dimensions. In order to handle deprecating the old path gracefully in Tokamax kernels, we temporarily add a special pattern match to WGMMA lowering. Once the change has propagated, and the relevant kernel changed appropriately, we will be able to remove the special paths and simply set `_handle_reshapes` to `True` when handling transforms. PiperOrigin-RevId: 907007299
Author
Parents
Loading