jax
d0b6eb2e - [Pallas][jax] Better error message for unexpected types in standard abstract eval

Commit
290 days ago
[Pallas][jax] Better error message for unexpected types in standard abstract eval This can happen if a user forgets to unwrap a ref! @asabne had this happen to him today, and he was confused as to what was going on. The prior error is unclear: AssertionError: (MemRef<None>{float32[2,1024,1024]}, MemRef<None>{float32[1,1024,1024]}) PiperOrigin-RevId: 749979253
Parents
Loading