pytorch
e8393131 - [generate_vmap_rule] support for jvp (#91211)

Commit
3 years ago
[generate_vmap_rule] support for jvp (#91211) Support for jvp is very similar to support for backward(): - We need to vmap over a version of the original autograd.Function's jvp method that does not take ctx as input. - On the output, we need to reductify to ensure the output tangent has the same shape as the output. This reductify does not have the extra reduction semantics, because PyTorch forward-mode AD requires the output tangent to have the same exact shape as the output. - setup_context needs to tell us the bdims of the saved_tensors (necessary for vmap over jvp_no_context), as well as the output shapes (necessary for reductify). Test Plan: - Added jvp support to the *GenVmapAutogradFunction Pull Request resolved: https://github.com/pytorch/pytorch/pull/91211 Approved by: https://github.com/soulitzer
Author
Committer
Parents
Loading