flax
851b585c
- Add graph=False support for nnx.grad and nnx.value_and_grad
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
1 day ago
Add graph=False support for nnx.grad and nnx.value_and_grad Add tree-mode support to nnx.grad/nnx.value_and_grad, where Variables are treated as pytree leaves and argnums are passed directly to JAX. PiperOrigin-RevId: 868389994
References
test_868389994
#5240 - Add graph=False support for nnx.grad and nnx.value_and_grad
Author
Cristian Garcia
Committer
a-googler
Parents
0bfe50a5
Loading