Partially migrate sort from THC to ATen, replace the thrust path with cub (#54626)
Summary:
The thrust path of `torch.sort` in THC is rewritten and replaced with cub in ATen. The original algorithm is followed, but since cub does not offer custom compare operator, I have to change it a bit to 2 sort + gather.
Note: tensor larger than 2^31 elements is supported, but the dimension being sorted can not go beyond 2^31.
Related: https://github.com/pytorch/pytorch/pull/50887 https://github.com/pytorch/pytorch/issues/24637
Benchmark:
```python
import torch
import itertools
for i in range(1000):
torch.arange(100000, device='cuda')
def run50_sync(f):
for _ in range(50):
f()
torch.cuda.synchronize()
for i, j in itertools.product([512, 4096, 8192], repeat=2):
print(i,j)
t = torch.randn(i, j, device='cuda')
torch.cuda.synchronize()
%timeit run50_sync(lambda: torch.sort(t))
torch.cuda.synchronize()
%timeit run50_sync(lambda: torch.sort(t, dim=0))
print()
```
Before
```
512 512
3.91 ms ± 8.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.87 ms ± 5.06 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
512 4096
70.5 ms ± 29.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
32.7 ms ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
512 8192
142 ms ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
64.4 ms ± 94.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 512
26.8 ms ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
82.2 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 4096
606 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
722 ms ± 94.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
4096 8192
1.28 s ± 157 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.54 s ± 500 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 512
53.5 ms ± 73.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
168 ms ± 39.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8192 4096
1.28 s ± 236 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.54 s ± 272 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 8192
2.69 s ± 741 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.28 s ± 549 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
After
```
512 512
4.02 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
512 4096
40.7 ms ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.9 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
512 8192
71.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.4 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 512
27.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 4096
262 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
321 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4096 8192
520 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
661 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 512
54.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.2 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8192 4096
521 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
645 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 8192
1.04 s ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 s ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54626
Reviewed By: VitalyFedyunin
Differential Revision: D27396078
Pulled By: ngimel
fbshipit-source-id: 4a23b9355e3542e49233b4b4328e43947ec17efd