jax
f6101063 - [Pallas] Clean up state Transforms

Commit
79 days ago
[Pallas] Clean up state Transforms Transforms are simplified to look like this: ``` class Transform(Protocol): def transform_type(self, x: core.AbstractValue) -> core.AbstractValue: ... ``` In addition, the Mosaic GPU use of transforms has been more "localized" in that the commutation of transforms is now a Mosaic GPU specific feature. Previously, the "untransform" methods were operating on MLIR values. They are now "commute" methods that operate on JAX values. We also remove MemoryRefTransform in favor of just using Transform. PiperOrigin-RevId: 866688134
Author
Parents
Loading