Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad.
Passing a truthy value for this argument will cause JAX to raise an error. It looks like this has been the case for a little more than a year -- see https://github.com/jax-ml/jax/pull/19970 .
PiperOrigin-RevId: 739299262