Initial vmap frontend API (#40172)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40172
This PR introduces the initial vmap frontend API. It has the following
limitations that we can resolve in the future:
- the inputs must be a flat list of tensors
- the outputs must be a flat list of tensors
- in_dims = 0 (so we always vmap over dim 0 of input tensors)
- out_dims = 0 (so the returned tensors have their vmap dim appear at
dim 0)
- Coverage is limited to operations that have batching rules implemented
(torch.mul, torch.sum, torch.expand); see the usage sketch after this
list.
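
As a quick illustration of the API under these constraints, here is a
hedged usage sketch. It assumes the function is exposed as `torch.vmap`,
which this summary doesn't actually specify:

```python
import torch

# Minimal usage sketch under the initial constraints: flat tensor
# inputs/outputs, in_dims=0, out_dims=0, and only ops with batching
# rules (torch.mul, torch.sum, torch.expand). `torch.vmap` is an
# assumed spelling of the new API.
x = torch.randn(3, 5)  # dim 0 (size 3) is the batch dim
y = torch.randn(3, 5)

batched_mul = torch.vmap(torch.mul)
out = batched_mul(x, y)  # shape: (3, 5), vmap dim at dim 0 (out_dims=0)

# Semantically equivalent to a Python loop over dim 0:
expected = torch.stack([torch.mul(x[i], y[i]) for i in range(3)])
assert torch.allclose(out, expected)
```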
There are some other semantic limitations (e.g., vmap cannot handle
arbitrary mutation, other than mutation performed by PyTorch operations)
that will be documented in the future.
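
To make the mutation caveat concrete, here is a hedged illustration; the
behavior in these cases isn't pinned down in this summary, and the names
are made up for the example:

```python
import torch

captured = []  # Python-level state mutated from inside the vmapped function

def f(x):
    # Mutation through a Python data structure: vmap only sees PyTorch
    # operations, so this side effect falls outside its reach. In-place
    # PyTorch ops (e.g. Tensor.add_) are the exception noted above, since
    # they can in principle be given batching rules.
    captured.append(x)
    return torch.mul(x, x)
```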
I wanted to introduce the API before adding a slow fallback for coverage
so that future batching rules (and the fallback itself) can be tested via
the Python API, which is less verbose than testing in C++-land.
The way vmap works is that `vmap(func)(inputs)` wraps all Tensor inputs
to be batched in BatchedTensors, sends those into func, and then unwraps
the output BatchedTensors. Operations on BatchedTensors dispatch to their
batching rules, which carry out the computation over the batch dimension.
When performing nested vmaps,
each nested vmap adds a batch dimension upon entry and removes a batch
dimension on exit.
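
As a reference for these semantics, here is a loop-based sketch of what
vmap computes. The real implementation does not loop; this only pins down
the result of the wrap/compute/unwrap flow for in_dims=0 and out_dims=0:

```python
import torch

def vmap_reference(func):
    """Loop-based reference semantics for vmap with in_dims=out_dims=0.

    The actual implementation wraps inputs in BatchedTensors and
    dispatches to batching rules instead of looping; this sketch only
    describes what that flow computes.
    """
    def wrapped(*inputs):
        batch_size = inputs[0].size(0)  # in_dims=0: batch over dim 0
        results = [func(*(t[i] for t in inputs)) for i in range(batch_size)]
        return torch.stack(results)     # out_dims=0: vmap dim at dim 0
    return wrapped

# Each nesting level adds one batch dim on entry and removes it on exit:
x = torch.randn(2, 3, 5)
out = vmap_reference(vmap_reference(torch.sum))(x)
print(out.shape)  # torch.Size([2, 3]): dims 0 and 1 were vmapped over
```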
Coming up in the near future:
- Support for non-zero in_dims and out_dims
- A docstring for vmap
- A slow fallback for operators that do not have a batching rule
implemented
Test Plan: `pytest test/test_vmap.py -v`
Differential Revision: D22102076
Pulled By: zou3519
fbshipit-source-id: b119f0a8a3a3b1717c92dbbd180dfb1618295563