[Pallas] Clean up state Transforms
Transforms are simplified to look like this:
```
class Transform(Protocol):
def transform_type(self, x: core.AbstractValue) -> core.AbstractValue:
...
```
In addition, the Mosaic GPU use of transforms has been more "localized"
in that the commutation of transforms is now a Mosaic GPU specific
feature. Previously, the "untransform" methods were operating on MLIR
values. They are now "commute" methods that operate on JAX values.
We also remove MemoryRefTransform in favor of just using Transform.
PiperOrigin-RevId: 866688134