2f37804c - [generate_vmap_rule] Add generate_vmap_rule to autograd.Function (#90966)

Design document: https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function. By setting it to True, a user promises that their autograd.Function's {forward, backward, jvp}, if defined, use only PyTorch operations, in addition to the other limitations of autograd.Function + functorch (such as not capturing any Tensors being transformed over from outside the autograd.Function).

Concretely, the approach is:
- We update `custom_function_call` to accept an additional `generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` with `generate_vmap_rule=True` is: we construct a vmapped version of the autograd.Function and dispatch on it.
- The vmapped version of the autograd.Function can be thought of as follows: if we have an autograd.Function Foo, then VmappedFoo.apply(in_dims, ...) has the same semantics as vmap(Foo.apply, in_dims)(...).
- VmappedFoo's forward, setup_context, and backward staticmethods are vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation.

Test Plan:
- This PR adds autograd.Functions with the suffix "GenVmap" to autograd_function_db.
- There are also some minor UX tests.

Future:
- jvp support
- Likely more testing to come, but please let me know if you have cases that you want tested here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90966
Approved by: https://github.com/soulitzer
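---

A minimal sketch of how the option is meant to be used, based on the description above. The `MySquare` function is an illustrative assumption (not from this PR), and `torch.func.vmap` is the entry point in current PyTorch; at the time of this commit the transform also lived in the functorch package:

```python
import torch
from torch.func import vmap  # functorch.vmap at the time of this commit

class MySquare(torch.autograd.Function):
    # Promise that forward/backward use only PyTorch ops, so a vmap
    # rule can be generated automatically instead of hand-written.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(3, 4)
# vmap works without the user writing a custom vmap rule, because
# generate_vmap_rule=True lets custom_function_call construct and
# dispatch to a vmapped version of MySquare under the hood.
y = vmap(MySquare.apply)(x)
```

Note the separate `setup_context` staticmethod: with this style, `forward` takes only the inputs, which is what allows the vmapped version of the Function to reuse Foo's staticmethods directly.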