jax
2b7b0742 - [Pallas TPU] Add lowerings for bf16 `jnp.ceil` and `jnp.floor` in TPU v6+

Commit
1 year ago
[Pallas TPU] Add lowerings for bf16 `jnp.ceil` and `jnp.floor` in TPU v6+ This PR is similar to https://github.com/jax-ml/jax/pull/24284 Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support. PiperOrigin-RevId: 688581914
Author
Parents
Loading