flax
c9252176 - Add tree-mode support for nnx.custom_vjp

Commit
13 days ago
Add tree-mode support for nnx.custom_vjp Enables nnx.custom_vjp to work without the graph protocol by treating Modules as regular JAX pytrees, consistent with how other NNX transforms (jit, grad, vmap, remat) already support tree-mode via the graph parameter. PiperOrigin-RevId: 873092826
Author
Cristian Garcia
Committer
Parents
Loading