flax
851b585c - Add graph=False support for nnx.grad and nnx.value_and_grad

Commit
1 day ago
Add graph=False support for nnx.grad and nnx.value_and_grad Add tree-mode support to nnx.grad/nnx.value_and_grad, where Variables are treated as pytree leaves and argnums are passed directly to JAX. PiperOrigin-RevId: 868389994
Author
Cristian Garcia
Committer
Parents
Loading