pytorch
3de86b95 - Migrate thrust->cub for index put (#55693)

Commit
3 years ago
Migrate thrust->cub for index put (#55693) Summary: 64bit indexing is not supported, because if `num_indices = 2^31`, then 4 long tensors of `num_indices` elements will take 64GB RAM. I don't think anybody will be interested in running `index_put` with 64GB GPU RAM. Benchmark on CUDA 11.3 RTX3090: ```python import torch import itertools def run50_sync(f): for _ in range(50): f() torch.cuda.synchronize() run50_sync(lambda: torch.randperm(1000000, device='cuda')) def benchmark(M, L): a = torch.randn(M, device='cuda') i1 = torch.randint(M, (L,), dtype=torch.long, device='cuda') v = torch.randn(L, device='cuda') torch.cuda.synchronize() %timeit run50_sync(lambda:a.index_put_((i1,), v, True)) for M, L in itertools.product((100, 100000, 10000000), repeat=2): print(M, L) benchmark(M, L) ``` Before ``` 100 100 5.13 ms ± 91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 100 100000 30.2 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 100 10000000 3.17 s ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 100000 100 5.19 ms ± 61.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 100000 100000 11.9 ms ± 200 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 100000 10000000 712 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 10000000 100 5.07 ms ± 66.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 10000000 100000 12.1 ms ± 76.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 10000000 10000000 627 ms ± 7.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` After ``` 100 100 3.75 ms ± 49.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 100 100000 26.2 ms ± 154 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 100 10000000 2.81 s ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 100000 100 3.85 ms ± 16.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 100000 100000 9.74 ms ± 40.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 100000 10000000 444 ms ± 1.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 10000000 100 3.85 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 10000000 100000 10.7 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 10000000 10000000 396 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/55693 Reviewed By: albanD Differential Revision: D27895967 Pulled By: ngimel fbshipit-source-id: 0616ce33395ce46f1a4161dfd38940b8e54fedc2
Author
Parents
Loading