Add batching rule for torch.sum(tensor, dims) (#39581)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39581
Context: Batching rules
------------------------------------
Batching rules take BatchedTensors and regular Tensors as arguments. A
batching rule generally does the following:
1. Converts (logical) BatchedTensors to views on physical tensors.
2. Converts logical arguments (e.g. dimension indexes, shapes) to
physical arguments that correspond to the physical tensors.
3. Calls at:: operations on the physical tensors and arguments.
4. Converts physical results back to BatchedTensors.
Steps 1 and 2 differ for operators with different batching behaviors.
(see next section)
VmapTransform abstraction
------------------------------------
(Previously known as a "Converter". Bikeshedding welcome on the naming).
An ArgTransform converts logical views of tensors to physical views.
When
writing a batching rule, users should select the ArgTransform that
matches
the batching behavior of their operator. If the batching behavior of the
op is complicated, then they’ll have to write some custom logic (either
by writing a new ArgTransform, or writing the logical->physical
transform
themselves).
*56% (~474) of (vmap-supported) operators can and will use these
VmapTransform. 20% (~168) of operators need custom handling*.
See `VmapTransforms.h` for more context.
PhysicalView
------------------------------------
VmapTransforms return physical views on tensors, represented by the
PhysicalView struct. It is effectively a Tensor and contains
enough metadata to enable mapping logical non-tensor arguments to
physical non-tensor arguments, and the other way around.
There are two methods on PhysicalView right now:
- `PhysicalView::getPhysicalDim(logical_dim)` and
`PhysicalView::getPhysicalDims(logical_dims)`.
are used to map logical dims to physical dims.
- `PhysicalView::newLogicalFromPhysical(Tensor)` is used to map a result
physical tensor from a batching rule to a logical tensor
(BatchedTensor).
Test Plan:
------------------------------------
- `./build/bin/vmap_test`
Differential Revision: D21983789
Pulled By: zou3519
fbshipit-source-id: dc558e05b596fd29f9643e933e4ece4b7866b6db