pytorch
b10c53e9 - Vectorize on output for reduction kernels (#37206)

Commit
4 years ago
Vectorize on output for reduction kernels (#37206) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37206 Benchmark on P100: https://github.com/zasdfgbnm/things/blob/master/2020Q2/reduction-benchmark-vectorize-output.ipynb ```python import torch print(torch.__version__) print() for i in range(1000): torch.arange(10000, device='cuda') def benchmark(dtype, i): size0 = 2 ** (i // 2) size1 = 2 ** ((i + 1) // 2) a = torch.zeros(size0, size1, device='cuda', dtype=dtype) torch.cuda.synchronize() %timeit a.sum(dtype=dtype, dim=0); torch.cuda.synchronize() for dtype in [torch.int8, torch.half, torch.float, torch.double]: print(dtype) for i in range(18, 30): benchmark(dtype, i) print() ``` Before ``` 1.5.0a0+3bbb36e torch.int8 24.5 µs ± 111 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 24.1 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 26.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 30.9 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 39 µs ± 504 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 59.6 µs ± 244 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 111 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 186 µs ± 300 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 397 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 665 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.45 ms ± 837 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 3.03 ms ± 2.79 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) torch.float16 24.2 µs ± 66.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 24.6 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 27.2 µs ± 53.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 32 µs ± 91 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 48.1 µs ± 89.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 66.9 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 121 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 218 µs ± 384 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 431 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 854 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.75 ms ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 3.63 ms ± 849 ns per loop (mean ± std. dev. of 7 runs, 100 loops each) torch.float32 24.2 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 24.4 µs ± 237 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 29.3 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 40.5 µs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 57.4 µs ± 44.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 85.5 µs ± 41.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 158 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 288 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 557 µs ± 904 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1e+03 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.98 ms ± 533 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 3.8 ms ± 1.98 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) torch.float64 25 µs ± 54.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 26.9 µs ± 320 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 37.1 µs ± 51.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 54.3 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 84.9 µs ± 65.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 139 µs ± 68.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 275 µs ± 235 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 504 µs ± 702 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 987 µs ± 613 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.84 ms ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 3.64 ms ± 2.44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 7.19 ms ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) ``` After ``` 1.5.0a0+3bbb36e torch.int8 29.8 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 30.7 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 33.4 µs ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 32.5 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 40.6 µs ± 94.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 53.7 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 68 µs ± 69.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 98.2 µs ± 88.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 158 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 283 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 522 µs ± 563 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 967 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) torch.float16 29.4 µs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 29.2 µs ± 45.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 30.8 µs ± 41 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 35.3 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 50.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 70.4 µs ± 67.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 101 µs ± 325 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 157 µs ± 179 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 275 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 486 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 936 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.85 ms ± 124 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) torch.float32 29.9 µs ± 36.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 29.5 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 33 µs ± 93.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 46 µs ± 37.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 64 µs ± 73.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 99.4 µs ± 82.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 157 µs ± 74.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 265 µs ± 68.8 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 490 µs ± 319 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 960 µs ± 669 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.84 ms ± 632 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 3.6 ms ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) torch.float64 33.1 µs ± 74.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 36.7 µs ± 86.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 46.7 µs ± 39.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 61.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 100 µs ± 23.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 158 µs ± 202 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 270 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 491 µs ± 445 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 939 µs ± 339 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.88 ms ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 3.65 ms ± 5.18 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 7.3 ms ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) ``` Test Plan: Imported from OSS Differential Revision: D21233255 Pulled By: ngimel fbshipit-source-id: d468fddbb228c0c13146dfc6344c470513f9e374
Author
Parents
Loading