pytorch
727463a7 - Initial vmap frontend API (#40172)

Commit
4 years ago
Initial vmap frontend API (#40172) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40172 This PR introduces the initial vmap frontend API. It has the following limitations that we can resolve in the future: - the inputs must be a flat list of tensors - the outputs must be a flat list of tensors - in_dims = 0 (so we always vmap over dim 0 of input tensors) - out_dims = 0 (so the returned tensors have their vmap dim appear at dim 0) - Coverage limited to operations that have batching rules implemented (torch.mul, torch.sum, torch.expand). There are some other semantic limitations (like not being able to handle mutation, aside from pytorch operations that perform mutation) that will be documented in the future. I wanted to introduce the API before adding a slow fallback for the coverage so that we can test future batching rules (and coverage) via the python API to avoid verbosity in C++-land. The way vmap works is that `vmap(func)(inputs)` wraps all Tensor inputs to be batched in BatchedTensors, sends those into func, and then unwraps the output BatchedTensors. Operations on BatchedTensors perform the batched operations that the user is asking for. When performing nested vmaps, each nested vmap adds a batch dimension upon entry and removes a batch dimension on exit. Coming up in the near future: - Support for non-zero in_dims and out_dims - docstring for vmap - slow fallback for operators that do not have a batching rule implemented. Test Plan: - `pytest test/test_vmap.py -v` Differential Revision: D22102076 Pulled By: zou3519 fbshipit-source-id: b119f0a8a3a3b1717c92dbbd180dfb1618295563
Author
Parents
Loading