Replace uses of deprecated jax.numpy functions:
- jnp.alltrue -> jnp.all
- jnp.sometrue -> jnp.any
- jnp.cumproduct -> jnp.cumprod
- jnp.product -> jnp.prod
These have been deprecated in JAX following similar deprecations in NumPy v1.25.0.
PiperOrigin-RevId: 537794332