pytorch
5152cf86 - masked_scatter thrust->cub (#56750)

masked_scatter thrust->cub (#56750)

Summary:
Benchmark:
```python
import torch
import itertools

def run50_sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()

run50_sync(lambda: torch.randperm(1000000, device='cuda'))

def benchmark(M):
    a = torch.randn(M, device='cuda')
    m = torch.randint(1, (M,), dtype=torch.long, device='cuda').bool()
    v = torch.randn(M, device='cuda')
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: a.masked_scatter_(m, v))

for M in (100, 1000, 100000, 10000000):
    print(M)
    benchmark(M)
```
Before:
```
100
8.65 ms ± 80.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1000
8.75 ms ± 72.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100000
9.27 ms ± 87.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10000000
33.6 ms ± 358 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
After:
```
100
8.04 ms ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1000
8.09 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
100000
8.63 ms ± 76.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10000000
31.9 ms ± 298 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56750

Reviewed By: ailzhang

Differential Revision: D28547564

Pulled By: ngimel

fbshipit-source-id: 83aeddfaf7023f9f9501c6b1e2faf91e8b6277b1
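For context, `Tensor.masked_scatter_(mask, source)` copies elements of `source` into the destination at the positions where `mask` is true, consuming `source` in order. On CUDA this is typically implemented by computing, for each position, an exclusive prefix sum over the mask to get its index into `source` — the device-wide scan that this commit switches from Thrust to CUB. A plain-Python sketch of the semantics built around that exclusive scan (an illustration, not the actual kernel):

```python
import itertools

def masked_scatter(dest, mask, source):
    """Reference semantics of masked_scatter_, sketched via an
    exclusive prefix sum over the mask (the scan moved to CUB)."""
    # Exclusive scan: prefix[i] = number of True mask entries before i,
    # i.e. the index into `source` used at position i when mask[i] is True.
    prefix = list(itertools.accumulate([0] + [int(m) for m in mask]))[:-1]
    return [source[prefix[i]] if m else d
            for i, (d, m) in enumerate(zip(dest, mask))]

# Example: the two source values land at the two True positions, in order.
print(masked_scatter([0, 0, 0, 0], [False, True, False, True], [7, 8]))
# → [0, 7, 0, 8]
```

On the GPU the scan runs in parallel over the whole mask, so swapping Thrust's scan for CUB's changes launch overhead and throughput without changing these semantics — which is what the benchmark above measures.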