Vectorize on output for reduction kernels (#37206)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37206
Benchmark on P100: https://github.com/zasdfgbnm/things/blob/master/2020Q2/reduction-benchmark-vectorize-output.ipynb
```python
import torch
print(torch.__version__)
print()
for i in range(1000):
torch.arange(10000, device='cuda')
def benchmark(dtype, i):
size0 = 2 ** (i // 2)
size1 = 2 ** ((i + 1) // 2)
a = torch.zeros(size0, size1, device='cuda', dtype=dtype)
torch.cuda.synchronize()
%timeit a.sum(dtype=dtype, dim=0); torch.cuda.synchronize()
for dtype in [torch.int8, torch.half, torch.float, torch.double]:
print(dtype)
for i in range(18, 30):
benchmark(dtype, i)
print()
```
Before
```
1.5.0a0+3bbb36e
torch.int8
24.5 µs ± 111 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.1 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.9 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
39 µs ± 504 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
59.6 µs ± 244 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
111 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
186 µs ± 300 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
397 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
665 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.45 ms ± 837 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.03 ms ± 2.79 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
torch.float16
24.2 µs ± 66.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.6 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
27.2 µs ± 53.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32 µs ± 91 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
48.1 µs ± 89.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66.9 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
121 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
218 µs ± 384 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
431 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
854 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.75 ms ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.63 ms ± 849 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
torch.float32
24.2 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.4 µs ± 237 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.3 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
40.5 µs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
57.4 µs ± 44.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
85.5 µs ± 41.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
288 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
557 µs ± 904 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1e+03 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.98 ms ± 533 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.8 ms ± 1.98 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
torch.float64
25 µs ± 54.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.9 µs ± 320 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
37.1 µs ± 51.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
54.3 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
84.9 µs ± 65.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
139 µs ± 68.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
275 µs ± 235 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
504 µs ± 702 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
987 µs ± 613 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.84 ms ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.64 ms ± 2.44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.19 ms ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
After
```
1.5.0a0+3bbb36e
torch.int8
29.8 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.7 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33.4 µs ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32.5 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
40.6 µs ± 94.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
53.7 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68 µs ± 69.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
98.2 µs ± 88.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
283 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
522 µs ± 563 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
967 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
torch.float16
29.4 µs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.2 µs ± 45.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.8 µs ± 41 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
35.3 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
50.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
70.4 µs ± 67.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
101 µs ± 325 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
157 µs ± 179 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
275 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
486 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
936 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.85 ms ± 124 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
torch.float32
29.9 µs ± 36.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.5 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33 µs ± 93.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46 µs ± 37.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
64 µs ± 73.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
99.4 µs ± 82.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
157 µs ± 74.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
265 µs ± 68.8 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
490 µs ± 319 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
960 µs ± 669 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.84 ms ± 632 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.6 ms ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
torch.float64
33.1 µs ± 74.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
36.7 µs ± 86.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46.7 µs ± 39.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
61.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
100 µs ± 23.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 202 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
270 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
491 µs ± 445 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
939 µs ± 339 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.88 ms ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.65 ms ± 5.18 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.3 ms ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
Test Plan: Imported from OSS
Differential Revision: D21233255
Pulled By: ngimel
fbshipit-source-id: d468fddbb228c0c13146dfc6344c470513f9e374