CUDA reduction: allow outputs to have different strides (#42649)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42364
Benchmark:
https://github.com/zasdfgbnm/things/blob/master/2020Q3/min-benchmark.ipynb
```python
import torch
print(torch.__version__)
print()
for i in range(100):
torch.randn(1000, device='cuda')
for e in range(7, 15):
N = 2 ** e
input_ = torch.randn(N, N, device='cuda')
torch.cuda.synchronize()
%timeit input_.min(dim=0); torch.cuda.synchronize()
input_ = torch.randn(N, N, device='cuda').t()
torch.cuda.synchronize()
%timeit input_.min(dim=0); torch.cuda.synchronize()
print()
```
Before
```
1.7.0a0+5d7c3f9
21.7 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.6 µs ± 773 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
22.5 µs ± 294 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.2 µs ± 250 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
26.4 µs ± 67 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.9 µs ± 316 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
33 µs ± 474 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
21.1 µs ± 218 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
84.2 µs ± 691 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
50.3 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
181 µs ± 2.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
145 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
542 µs ± 753 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
528 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.04 ms ± 9.74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.01 ms ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
After
```
1.7.0a0+9911817
21.4 µs ± 695 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.6 µs ± 989 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
22.4 µs ± 153 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.5 µs ± 58.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
26.6 µs ± 147 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.9 µs ± 675 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
35.4 µs ± 560 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
21.7 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
86.5 µs ± 1.99 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
52.2 µs ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
195 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
153 µs ± 4.46 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
550 µs ± 7.72 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
527 µs ± 3.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.05 ms ± 7.87 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2 ms ± 4.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42649
Reviewed By: ezyang
Differential Revision: D22994446
Pulled By: ngimel
fbshipit-source-id: cc60beebad2e04c26ebf3ca702a6cb05846522c9