jax
3a7d9137 - [Pallas TPU] Support ref reshape.

Commit
1 year ago
[Pallas TPU] Support ref reshape. Jaxpr example: ``` { lambda ; a:MemRef<None>{int32[32,256]} b:MemRef<None>{int32[8,128]}. let c:i32[8,128] <- a[:16,:][bitcast(int16[32,256])][reshape(int16[2,16,256])][bitcast(float16[2,16,256])][1:,:,:][reshape(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:] b[:,:] <- c in () } ``` Tested: - DMA with reshaped ref - Load from reshaped ref - Store to reshaped ref - Multiple transforms - Interpret Mode for ref transforms (updated discharge rules) PiperOrigin-RevId: 686186426
Author
Parents
Loading