Implement torch.sort with cub::DeviceSegmentedRadixSort (#56821)
Summary:
Benchmark:
```python
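import torch
import itertools

# Warm up the CUDA context and allocator before timing.
for i in range(1000):
    torch.arange(100000, device='cuda')

def run50_sync(f):
    # Launch f 50 times, then wait for all queued GPU work to finish.
    for _ in range(50):
        f()
    torch.cuda.synchronize()

# Run in IPython/Jupyter: %timeit is an IPython magic, not plain Python.
for i, j in itertools.product([512, 4096, 8192], repeat=2):
    print(i, j)
    t = torch.randn(i, j, device='cuda')
    torch.cuda.synchronize()
    # Sort along the last dimension (the default, dim=-1).
    %timeit run50_sync(lambda: torch.sort(t))
    torch.cuda.synchronize()
    # Sort along the first dimension.
    %timeit run50_sync(lambda: torch.sort(t, dim=0))
    print()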
```
Before
```
512 512
4.02 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
512 4096
40.7 ms ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.9 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
512 8192
71.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.4 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 512
27.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 4096
262 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
321 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4096 8192
520 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
661 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 512
54.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.2 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8192 4096
521 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
645 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 8192
1.04 s ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 s ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
After
```
512 512
4.65 ms ± 62.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.75 ms ± 62.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
512 4096
30.3 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.4 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
512 8192
59.7 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
77 ms ± 601 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 512
32.2 ms ± 376 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
37.1 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 4096
204 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
272 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4096 8192
422 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
562 ms ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 512
63.1 ms ± 595 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
72.7 ms ± 532 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8192 4096
401 ms ± 3.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
573 ms ± 2.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 8192
831 ms ± 7.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.2 s ± 9.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
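To make the two listings easier to compare, the following sketch computes the per-shape ratio of the Before time to the After time. The timings are transcribed from the output above and converted to milliseconds. The helper name `speedups` and the dictionary layout are illustrative choices and are not part of the PR. For each shape, the first value is the default `torch.sort(t)` call and the second is `torch.sort(t, dim=0)`.

```python
# Timings in milliseconds, transcribed from the Before/After listings above.
# Each entry maps (rows, cols) -> (sort along last dim, sort along dim=0).
before = {
    (512, 512): (4.02, 5.0),
    (512, 4096): (40.7, 33.9),
    (512, 8192): (71.7, 66.4),
    (4096, 512): (27.6, 46.6),
    (4096, 4096): (262.0, 321.0),
    (4096, 8192): (520.0, 661.0),
    (8192, 512): (54.6, 83.2),
    (8192, 4096): (521.0, 645.0),
    (8192, 8192): (1040.0, 1340.0),
}
after = {
    (512, 512): (4.65, 5.75),
    (512, 4096): (30.3, 39.4),
    (512, 8192): (59.7, 77.0),
    (4096, 512): (32.2, 37.1),
    (4096, 4096): (204.0, 272.0),
    (4096, 8192): (422.0, 562.0),
    (8192, 512): (63.1, 72.7),
    (8192, 4096): (401.0, 573.0),
    (8192, 8192): (831.0, 1200.0),
}


def speedups(before, after):
    """Return {shape: (last-dim ratio, dim=0 ratio)}; a ratio > 1 means After is faster."""
    return {
        shape: tuple(b / a for b, a in zip(before[shape], after[shape]))
        for shape in before
    }


for shape, (last_dim, dim0) in speedups(before, after).items():
    print(f"{shape[0]}x{shape[1]}: last dim {last_dim:.2f}x, dim=0 {dim0:.2f}x")
```

A ratio below 1 marks a shape where the After timing is slower than the Before timing. In the numbers above this happens for 512x512 in both directions and for a few of the other shapes involving a 512-sized dimension, while the shapes where both dimensions are 4096 or 8192 are faster in both directions.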
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56821
Reviewed By: mrshenli
Differential Revision: D28172609
Pulled By: ngimel
fbshipit-source-id: 87314a6985a84d326304ff5220df5661ef00d710