[Pallas][jax] Better error message for unexpected types in standard abstract eval
This can happen if a user forgets to unwrap a ref!
@asabne had this happen to him today, and he was confused as to what was going on. The prior error is unclear:
AssertionError: (MemRef<None>{float32[2,1024,1024]}, MemRef<None>{float32[1,1024,1024]})
PiperOrigin-RevId: 749979253