Allow vmap to accept nested python data structures as inputs (#46289)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46289
Previously, vmap had the restriction that any Tensors in the inputs must
not be a part of a nested python collection. This PR relaxes that
restriction. We could do the same thing for vmap outputs, but I'll
leave that for future work.
The mechanism behind vmap is to convert any Tensor inputs (that have
been specified via in_dims) into BatchedTensor. Using a pytree
implementation, that logic becomes:
- Flatten the inputs.
- Broadcast in_dims to the inputs' structure and flatten it.
- Use the flat inputs and flat in_dims to construct BatchedTensors.
- Unflatten the BatchedTensors into the same structure as the original
inputs.
- Send the unflattened BatchedTensors into the desired function.
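The steps above can be sketched in pure Python. This is an illustrative toy, not the actual implementation (which uses a pytree utility and real BatchedTensors internally): `tree_flatten`, `tree_unflatten`, and `vmap_sketch` are hypothetical helpers, and a `("batched", ...)` tag stands in for constructing a BatchedTensor.

```python
# Toy sketch of pytree-based vmap input handling.
# Leaves here are anything that's not a list/tuple.

def tree_flatten(tree):
    """Flatten a nested list/tuple structure into (leaves, spec)."""
    if not isinstance(tree, (list, tuple)):
        return [tree], None  # a leaf
    leaves, child_specs = [], []
    for child in tree:
        child_leaves, child_spec = tree_flatten(child)
        leaves.extend(child_leaves)
        child_specs.append((len(child_leaves), child_spec))
    return leaves, (type(tree), child_specs)

def tree_unflatten(leaves, spec):
    """Rebuild the original nesting from flat leaves and a spec."""
    if spec is None:
        return leaves[0]
    container_type, child_specs = spec
    children, i = [], 0
    for num_leaves, child_spec in child_specs:
        children.append(tree_unflatten(leaves[i:i + num_leaves], child_spec))
        i += num_leaves
    return container_type(children)

def vmap_sketch(func, in_dims=0):
    """Hypothetical stand-in for vmap's input handling."""
    def wrapped(*args):
        flat_args, spec = tree_flatten(args)          # 1. flatten inputs
        flat_in_dims = [in_dims] * len(flat_args)     # 2. broadcast in_dims
        batched = [("batched", a, d)                  # 3. tag in place of
                   for a, d in zip(flat_args, flat_in_dims)]  # BatchedTensor
        new_args = tree_unflatten(batched, spec)      # 4. restore nesting
        return func(*new_args)                        # 5. call the function
    return wrapped

# Nested inputs survive the round trip: the inner list keeps its structure.
result = vmap_sketch(lambda a, pair: (a, pair))(1, [2, 3])
# result == (("batched", 1, 0), [("batched", 2, 0), ("batched", 3, 0)])
```

The spec records each container's type and how many leaves each child consumed, which is exactly enough to invert the flattening.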
Performance
-----------
Some benchmarking using
```
import torch

def foo(a, b, c, d):
    return a, b, c, d

x = torch.randn(2, 3)
foo_vmap = torch.vmap(foo)

%timeit foo_vmap(x, x, x, x)  # %timeit is IPython magic
```
shows a slowdown from 15us to 25us on my machine. The 10us overhead is
not a lot, especially since our vmap implementation is a "prototype". We
can reduce this overhead in the future by either moving part of the
pytree implementation into C++ or depending on a library that has a
performant pytree implementation.
Test Plan
---------
- New tests, also updated old tests.
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D24392892
Pulled By: zou3519
fbshipit-source-id: 072b21dcc6065ab43cfd341e84a01a5cc8ec3daf