jax
66dbacad - Add DontWant and DidntWant (the best names ever)

Commit
3 days ago
Add DontWant and DidntWant (the best names ever) **Why?** Because we want a way to say that we don't want a gradient for this primal. This will also allow us to skip work in transpose_fancy rules when the operand is a NullAccum. **What's the API look like now?** After you get the `f_vjp` from `jax.vjp`, you can call `with_refs` with these values: `f_vjp.with_ref(x: Ref | DontWant | GradValue) : (GradRef | DidntWant | Jax.Array)` Below is the mapping from user input -> internal accum -> output type. * Ref (pytree) -- RefAccum -- GradRef (sentinal) * DontWant (sentinal) -- NullAccum -- DidntWant (sentinal) * GradValue (sentinal) -- ValAccum -- jax.Array (pytree) Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 879182894
Author
Parents
Loading