jax
d1e9445b - Fix rank mismatches in Pallas memory reference transformations.

Commit
49 days ago
Fix rank mismatches in Pallas memory reference transformations. Update `undo_transforms` to track abstract value updates during the forward pass and ensure `get_ref_aval` uses correctly transformed physical avals. This fixes a regression in GLU kernels involving bespoke transform sequences (swizzle, transpose, tile). Also adds a regression test. PiperOrigin-RevId: 881400869
Author
Parents
Loading