pytorch
4809e838 - functorch.jvp support for autograd.Function (#90077)

Commit
3 years ago
functorch.jvp support for autograd.Function (#90077) This PR adds functorch.jvp support for autograd.Function. It does so by adding a jvp rule for custom_function_call. For a regular PyTorch operation (like at::sin), the VariableType kernel: - re-dispatches to at::sin - calls the jvp rule for at::sin The jvp rule for custom_function_call does just that. It constructs a new autograd.Function (because the above logic already exists). Inside the forward, it re-dispatches to custom_function_call. In the jvp rule, it just calls whatever the jvp rule is supposed to be. Since this logic is really close to the custom_function_call_grad, I just put them together. Test Plan: - added jvp rules to the autograd.Function in autograd_function_db Pull Request resolved: https://github.com/pytorch/pytorch/pull/90077 Approved by: https://github.com/albanD, https://github.com/soulitzer
Author
Committer
Parents
Loading