Implement multiple dispatch (#25653)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25653
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id.
Billing of changes:
* ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.)
* Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments.
* The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check.
After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse.
* Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++.
* One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it.
* A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message)
* `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch.
* `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity.
* c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely.
Benchmark:
Apply the following patch to the base commit and this commit:
```
diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp
new file mode 100644
index 0000000000..b66f4d3ece
--- /dev/null
+++ b/aten/src/ATen/native/Const.cpp
@@ -0,0 +1,10 @@
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+
+Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) {
+ return self;
+}
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b494ed7950..fddae638bb 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5878,3 +5878,9 @@
dispatch:
CPU: im2col_backward_cpu
CUDA: im2col_backward_cuda
+
+# For benchmarking
+- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor
+ variants: function
+ dispatch:
+ CPU: _const5
```
Comparisons with timeit:
One-argument, representative case:
Before:
```
In [6]: %timeit x.reshape(1, 1)
1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [7]: %timeit x.reshape(1, 1)
1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [8]: %timeit x.reshape(1, 1)
1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
After:
```
In [3]: %timeit x.reshape(1, 1)
1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit x.reshape(1, 1)
1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit x.reshape(1, 1)
1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments):
Before:
```
In [1]: import torch
In [2]: x = torch.zeros(1)
In [3]: %timeit torch._const5(x, x, x, x, x)
949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit torch._const5(x, x, x, x, x)
954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit torch._const5(x, x, x, x, x)
947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
After:
```
In [3]: %timeit torch._const5(x, x, x, x, x)
985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit torch._const5(x, x, x, x, x)
984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit torch._const5(x, x, x, x, x)
988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D17265918
Pulled By: ezyang
fbshipit-source-id: 221efe4e86a40f36abc81e2ebceaa7e251c90b3d