pytorch
31981d01 - [generate_vmap_rule] add restore_vmap helper function (#90963)

Commit
2 years ago
[generate_vmap_rule] add restore_vmap helper function (#90963) As seen in https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit `restore_vmap` is a private helper function. It is vmap but has the following differences: - instead of returning outputs, it returns an (outputs, out_dims) tuple. out_dims is a pytree of shape shape as outputs and contains Optional[int] specifying where the vmapped dimension, if it exists, is in the corresponding output. - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped). restore_vmap allows for no inputs to have the vmap dimension - does no validation on outputs (vmap expects only Tensor outputs) restore_vmap allows for return of arbitrary outputs (not just Tensors) Test Plan: - added some simple test to test restore_vmap - I am OK with restore_vmap not being a part of vmap right now -- the implementation of vmap rarely changes and it is a bit difficult to refactor vmap in a way that restore_vmap is a subroutine. Other questions: - Bikeshedding the `restore_vmap` name Pull Request resolved: https://github.com/pytorch/pytorch/pull/90963 Approved by: https://github.com/samdow, https://github.com/soulitzer
Author
Committer
Parents
Loading