jax
a6f03d3d - [Mosaic GPU] Execute relayouts in the lowering of `layout_cast`.

Commit
349 days ago
[Mosaic GPU] Execute relayouts in the lowering of `layout_cast`. As I was rewriting Pallas tests, I noticed a case where we essentially would generate code with the form: ``` x = create() # out_layout: strided layout y = layout_cast x # {in,out}_layout = WGMMA layout z = consume(y) # in_layout: strided layout ``` which would end up silently succeeding. Since the intent is for `layout_cast` to be used as a user annotation to guide layout inference, we should make sure to crash in such cases---when a bad relayout is involved, and the layout we intended to use is not propagated. PiperOrigin-RevId: 773592922
Author
Parents
Loading