jax
170718c8 - Change signature of linearization rules.

Commit
1 year ago
Change signature of linearization rules. Give the rule the nonzero tangent pattern up-front. This is needed to make a linearization rule for pjit_p. Also make the rules return the nonzero tangents out, an explicit residual, and a closed tangent function. Add a rule for sin_p to test it out. We still need to figure out how to avoid having to precompute `cos(x)`. I think we need to update our backward pass code.
Author
Committer
Parents
Loading