jax
dc11d402 - [Pallas TPU] Better error message for lowering `sp.broadcast_to_p`

Commit
1 year ago
[Pallas TPU] Better error message for lowering `sp.broadcast_to_p` `sp.broadcast_to_p` is a GPU-specific primitive, but it mistakenly appears in TPU lowerings. This PR improves the error message to reflect this. As an example, currently, users will hit this error when doing: ``` def kernel(x_ref, o_ref): m, n = 32, 8 x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], jnp.arange(n, dtype=jnp.int32)[None])) o_ref[...] = x ``` PiperOrigin-RevId: 700290975
Author
Parents
Loading