[generate_vmap_rule] add restore_vmap helper function (#90963)
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit
`restore_vmap` is a private helper function. It is vmap but has the
following
differences:
- instead of returning outputs, it returns an (outputs, out_dims) tuple.
out_dims is a pytree of shape shape as outputs and contains Optional[int]
specifying where the vmapped dimension, if it exists, is in the
corresponding output.
- does no validation on in_dims or inputs (vmap expects at least one
Tensor to be vmapped).
restore_vmap allows for no inputs to have the vmap dimension
- does no validation on outputs (vmap expects only Tensor outputs)
restore_vmap allows for return of arbitrary outputs (not just
Tensors)
Test Plan:
- added some simple test to test restore_vmap
- I am OK with restore_vmap not being a part of vmap right now -- the
implementation of vmap rarely changes and it is a bit difficult to
refactor vmap in a way that restore_vmap is a subroutine.
Other questions:
- Bikeshedding the `restore_vmap` name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90963
Approved by: https://github.com/samdow, https://github.com/soulitzer