pytorch
383e4510 - Implement torch.sort with cub::DeviceSegmentedRadixSort (#56821)

Commit
3 years ago
Implement torch.sort with cub::DeviceSegmentedRadixSort (#56821) Summary: Benchmark: ```python import torch import itertools for i in range(1000): torch.arange(100000, device='cuda') def run50_sync(f): for _ in range(50): f() torch.cuda.synchronize() for i, j in itertools.product([512, 4096, 8192], repeat=2): print(i,j) t = torch.randn(i, j, device='cuda') torch.cuda.synchronize() %timeit run50_sync(lambda: torch.sort(t)) torch.cuda.synchronize() %timeit run50_sync(lambda: torch.sort(t, dim=0)) print() ``` Before ``` 512 512 4.02 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 5 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 512 4096 40.7 ms ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 33.9 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 512 8192 71.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 66.4 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 4096 512 27.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 46.6 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 4096 4096 262 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 321 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 4096 8192 520 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 661 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) 8192 512 54.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 83.2 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 8192 4096 521 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 645 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 8192 8192 1.04 s ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 1.34 s ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` After ``` 512 512 4.65 ms ± 62.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 5.75 ms ± 62.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 512 4096 30.3 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 39.4 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 512 8192 59.7 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 77 ms ± 601 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 4096 512 32.2 ms ± 376 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 37.1 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 4096 4096 204 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) 272 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 4096 8192 422 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 562 ms ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 8192 512 63.1 ms ± 595 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 72.7 ms ± 532 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 8192 4096 401 ms ± 3.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 573 ms ± 2.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 8192 8192 831 ms ± 7.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 1.2 s ± 9.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/56821 Reviewed By: mrshenli Differential Revision: D28172609 Pulled By: ngimel fbshipit-source-id: 87314a6985a84d326304ff5220df5661ef00d710
Author
Parents
Loading