pytorch
aaaf2eb6 - Add batching rule for torch.sum(tensor, dims) (#39581)

Commit
4 years ago
Add batching rule for torch.sum(tensor, dims) (#39581) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39581 Context: Batching rules ------------------------------------ Batching rules take BatchedTensors and regular Tensors as arguments. A batching rule generally does the following: 1. Converts (logical) BatchedTensors to views on physical tensors. 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical arguments that correspond to the physical tensors. 3. Calls at:: operations on the physical tensors and arguments. 4. Converts physical results back to BatchedTensors. Steps 1 and 2 differ for operators with different batching behaviors. (see next section) VmapTransform abstraction ------------------------------------ (Previously known as a "Converter". Bikeshedding welcome on the naming). An ArgTransform converts logical views of tensors to physical views. When writing a batching rule, users should select the ArgTransform that matches the batching behavior of their operator. If the batching behavior of the op is complicated, then they’ll have to write some custom logic (either by writing a new ArgTransform, or writing the logical->physical transform themselves). *56% (~474) of (vmap-supported) operators can and will use these VmapTransform. 20% (~168) of operators need custom handling*. See `VmapTransforms.h` for more context. PhysicalView ------------------------------------ VmapTransforms return physical views on tensors, represented by the PhysicalView struct. It is effectively a Tensor and contains enough metadata to enable mapping logical non-tensor arguments to physical non-tensor arguments, and the other way around. There are two methods on PhysicalView right now: - `PhysicalView::getPhysicalDim(logical_dim)` and `PhysicalView::getPhysicalDims(logical_dims)`. are used to map logical dims to physical dims. - `PhysicalView::newLogicalFromPhysical(Tensor)` is used to map a result physical tensor from a batching rule to a logical tensor (BatchedTensor). Test Plan: ------------------------------------ - `./build/bin/vmap_test` Differential Revision: D21983789 Pulled By: zou3519 fbshipit-source-id: dc558e05b596fd29f9643e933e4ece4b7866b6db
Author
Parents
Loading