flax
f18baf07 - Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad.

Commit
329 days ago
Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad. Passing a truthy value for this argument will cause JAX to raise an error. It looks like this has been the case for a little more than a year -- see https://github.com/jax-ml/jax/pull/19970 . PiperOrigin-RevId: 739299262
Author
Committer
Parents
Loading