Migrate addmv and mv from legacy to ATen native (CUDA & CPU) (#30898)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/24605 https://github.com/pytorch/pytorch/issues/24535 https://github.com/pytorch/pytorch/issues/24739 https://github.com/pytorch/pytorch/issues/24680 https://github.com/pytorch/pytorch/issues/30986
This does not fix https://github.com/pytorch/pytorch/issues/29984, it will be fixed in later PR.
Most of this PR is just following the same logic inside TH and THC except the handle of n-dimensional zero-sized tensor, in specific the case:
```
(m,).addmv((m, 0), (0,), beta, alpha)
```
# Legacy code bugs and how this PR deal with it
The above case is a case where BLAS often have a mismatch of semantics with PyTorch: For BLAS and cuBLAS, the above is a noop, but for PyTorch, it is a scalar-vector multiplication `output = beta * input`. The handle of this case is already very poor in legacy code and it is poorly tested:
For the CPU implementation, there are two code paths:
- Path 1: when dtype is float or double and `USE_BLAS`, then use BLAS
- Path 2: when other dtypes or not `USE_BLAS`, use a fallback kernel in PyTorch
For the CUDA implementation, there are also two code paths:
- Path 1: when float or double, then use `cublasSgemv` or `cublasDgemv` in cuBlas
- Path 2: when half, dispatch to `addmm`
`test_blas_alpha_beta_empty` is supposed to cover all cases, but unfortunately, it only tests the Path 1 of CUDA and Path 1 of CPU, and both uncovered paths (path 2 for CPU and path 2 for CUDA) are buggy in legacy code. In this PR, I expanded the coverage of `test_blas_alpha_beta_empty`, but unfortunately, I have to skip the `half` dtype on CUDA 9. See the description below for detail:
## Bug on CPU implementation
For the CPU implementation, the fallback kernel in path 2 already has the same semantics as PyTorch, not BLAS. But the code that tries to correct BLAS semantics to match PyTorch also runs on this case, leading to double correction, that is, `output = beta * input` now becomes `output = beta * beta * input`.
This leads to the issue https://github.com/pytorch/pytorch/issues/30986 I just opened, and it is fixed in this PR.
## Bug on CUDA implementation
For the CUDA implementation, path 2 dispatches to
```
(m, 1).addmm((m, 0), (0, 1), beta, alpha)
```
But unfortunately, for some old CUDA version when on old GPU on half dtype, the above is also noop, which is definitely not correct.
But from what I see, on newer CUDA version or newer GPU, this is not a problem. This is a bug of PyTorch in `addmm`, so I opened a new issue https://github.com/pytorch/pytorch/issues/31006 to track this problem. But this is highly likely a dependency bug for PyTorch originating from cuBLAS, and it is only on a rarely used edge case on old hardware and software, so this issue would be a `won't_fix` unless some real requirements strongly indicate that this should be fixed.
This issue is already with legacy code, and this PR does not make it worse. To prevent this issue from bothering us, I disable the test of `half` dtype for CUDA 9 when expanding the coverage of `test_blas_alpha_beta_empty`.
I promote a CircleCI CUDA 10.1 test to `XImportant` so that it runs on PRs, because the path 2 of CUDA implementation is only covered by this configuration. Let me know if I should revert this change.
## An additional problem
In legacy code for `addmv`, dtype `bfloat16` is enabled and dispatch to `addmm`, but `addmm` does not support `bfloat16` from what I test. I do the same thing in the new code. Let me know if I should do it differently.
# Benchmark
Code:
```python
import torch
print(torch.__version__)
for i in range(1000):
torch.arange(i, device='cuda')
print('cpu')
for i in 10, 100, 1000, 10000:
a = torch.randn((i,))
b = torch.randn((i, i))
c = torch.randn((i,))
%timeit a.addmv(b, c, alpha=1, beta=2)
print('cuda')
for i in 10, 100, 1000, 10000:
a = torch.randn((i,)).cuda()
b = torch.randn((i, i)).cuda()
c = torch.randn((i,)).cuda()
torch.cuda.synchronize()
%timeit a.addmv(b, c, alpha=1, beta=2); torch.cuda.synchronize()
```
Before:
```
1.5.0a0+2b45368
cpu
2.74 µs ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
8.5 µs ± 85.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
686 µs ± 2.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
74 ms ± 410 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cuda
The slowest run took 4.81 times longer than the fastest. This could mean that an intermediate result is being cached.
27.6 µs ± 23 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
17.3 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
20.5 µs ± 369 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
756 µs ± 6.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
After:
```
1.5.0a0+66b4034
cpu
3.29 µs ± 20 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.09 µs ± 7.41 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
687 µs ± 7.01 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
73.8 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cuda
18.2 µs ± 478 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
17.7 µs ± 299 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
21.5 µs ± 2.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
751 µs ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30898
Differential Revision: D20660338
Pulled By: anjali411
fbshipit-source-id: db1f521f124198f63545064026f93fcb16b68f18