jax
cdb466c5 - [attrs] add a jvp function with attrs support

Commit
1 year ago
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
Author
Committer
Parents
Loading