flax
5ff36ba4 - Add simple argument-only lifted nn.grad and nn.value_and_grad functions.

Commit
2 years ago
Add simple argument-only lifted nn.grad and nn.value_and_grad functions. This function only peforms a lifted value-and-grad operation with respect to the arguments of a function, and does not try to calculate gradients with respect to any variables. This mirrors the behavior of the haiku grad function, and also easily works in the multi-scope setting (external modules passed in) while avoiding the complexities associated with that case for a more general vjp.
Author
Committer
Parents
Loading