Migrate thrust->cub for index put (#55693)
Summary:
64bit indexing is not supported, because if `num_indices = 2^31`, then 4 long tensors of `num_indices` elements will take 64GB RAM. I don't think anybody will be interested in running `index_put` with 64GB GPU RAM.
Benchmark on CUDA 11.3 RTX3090:
```python
import torch
import itertools
def run50_sync(f):
for _ in range(50):
f()
torch.cuda.synchronize()
run50_sync(lambda: torch.randperm(1000000, device='cuda'))
def benchmark(M, L):
a = torch.randn(M, device='cuda')
i1 = torch.randint(M, (L,), dtype=torch.long, device='cuda')
v = torch.randn(L, device='cuda')
torch.cuda.synchronize()
%timeit run50_sync(lambda:a.index_put_((i1,), v, True))
for M, L in itertools.product((100, 100000, 10000000), repeat=2):
print(M, L)
benchmark(M, L)
```
Before
```
100 100
5.13 ms ± 91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100 100000
30.2 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
100 10000000
3.17 s ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
100000 100
5.19 ms ± 61.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100000 100000
11.9 ms ± 200 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100000 10000000
712 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
10000000 100
5.07 ms ± 66.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10000000 100000
12.1 ms ± 76.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10000000 10000000
627 ms ± 7.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
After
```
100 100
3.75 ms ± 49.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100 100000
26.2 ms ± 154 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
100 10000000
2.81 s ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
100000 100
3.85 ms ± 16.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100000 100000
9.74 ms ± 40.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100000 10000000
444 ms ± 1.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
10000000 100
3.85 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10000000 100000
10.7 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10000000 10000000
396 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55693
Reviewed By: albanD
Differential Revision: D27895967
Pulled By: ngimel
fbshipit-source-id: 0616ce33395ce46f1a4161dfd38940b8e54fedc2