pytorch
d2784c23 - Partially migrate sort from THC to ATen, replace the thrust path with cub (#54626)

Commit
3 years ago
Summary:
The thrust path of `torch.sort` in THC is rewritten and replaced with cub in ATen. The original algorithm is followed, but since cub does not offer a custom compare operator, I had to change it a bit to 2 sort + gather.

Note: tensors larger than 2^31 elements are supported, but the dimension being sorted cannot go beyond 2^31.

Related:
https://github.com/pytorch/pytorch/pull/50887
https://github.com/pytorch/pytorch/issues/24637

Benchmark:
```python
import torch
import itertools

# Warm-up loop before timing.
for i in range(1000):
    torch.arange(100000, device='cuda')

def run50_sync(f):
    # Launch f 50 times, then wait for all queued GPU work to finish.
    for _ in range(50):
        f()
    torch.cuda.synchronize()

# %timeit is an IPython magic, so run this in IPython or Jupyter.
for i, j in itertools.product([512, 4096, 8192], repeat=2):
    print(i,j)
    t = torch.randn(i, j, device='cuda')
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t))
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t, dim=0))
    print()
```

Before
```
512 512
3.91 ms ± 8.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.87 ms ± 5.06 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

512 4096
70.5 ms ± 29.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
32.7 ms ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

512 8192
142 ms ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
64.4 ms ± 94.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 512
26.8 ms ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
82.2 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 4096
606 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
722 ms ± 94.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

4096 8192
1.28 s ± 157 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.54 s ± 500 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 512
53.5 ms ± 73.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
168 ms ± 39.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

8192 4096
1.28 s ± 236 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.54 s ± 272 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 8192
2.69 s ± 741 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.28 s ± 549 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

After
```
512 512
4.02 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

512 4096
40.7 ms ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.9 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

512 8192
71.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.4 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 512
27.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 4096
262 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
321 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

4096 8192
520 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
661 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 512
54.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.2 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

8192 4096
521 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
645 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 8192
1.04 s ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 s ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54626

Reviewed By: VitalyFedyunin

Differential Revision: D27396078

Pulled By: ngimel

fbshipit-source-id: 4a23b9355e3542e49233b4b4328e43947ec17efd
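For readers unfamiliar with the "2 sort + gather" idea mentioned in the summary, here is a minimal CPU sketch in plain Python. It is an illustration of one way a segmented sort can be built from two stable sorts without a custom comparator, not the CUDA code from this commit. The function name `segmented_sort` and the parameter `nsort` are hypothetical, and the exact order of steps in the ATen implementation is an assumption.

```python
def segmented_sort(values, nsort):
    """Sort each contiguous segment of length `nsort` in the flat list `values`.

    Hypothetical helper for illustration only. Returns (sorted_values, indices),
    where `indices` holds each element's original position within its segment.
    """
    n = len(values)
    assert n % nsort == 0, "flat length must be a multiple of the segment length"

    # Sort 1: stable sort of all flat positions by value, ignoring segments.
    by_value = sorted(range(n), key=lambda p: values[p])

    # Sort 2: stable sort of those positions by segment id (p // nsort).
    # Stability preserves the value order from sort 1 inside each segment.
    by_segment = sorted(by_value, key=lambda p: p // nsort)

    # Gather: read values in their final order and convert flat positions
    # into per-segment indices.
    sorted_values = [values[p] for p in by_segment]
    indices = [p % nsort for p in by_segment]
    return sorted_values, indices


if __name__ == "__main__":
    # Two segments of length 3: [5, 1, 4] and [2, 0, 3].
    vals, idx = segmented_sort([5, 1, 4, 2, 0, 3], nsort=3)
    print(vals)  # [1, 4, 5, 0, 2, 3]
    print(idx)   # [1, 2, 0, 1, 0, 2]
```

Both passes use only plain key ordering, which is why this shape fits a sorting primitive that does not accept a custom compare operator: the second key (segment id) is applied as a separate stable sort instead of being folded into a comparator.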