pytorch
ed0a572e - Migrate `scatter` and `scatter_` from the TH to Aten (CUDA) (#35697)

Summary: Fixes [#24621](https://github.com/pytorch/pytorch/issues/24621). Some preliminary results:

## Case 1: dense indexing

```python
import torch
import numpy
from IPython import get_ipython

numpy.random.seed(13)
torch.manual_seed(13)
ipython = get_ipython()

Ms = 1024 * 8
Ns = 1024 * 4
dim = 0
top_power = 4

for pM in range(top_power):
    M = Ms * (2 ** pM)
    for pN in range(top_power):
        N = Ns * (2 ** pN)
        input_ = torch.rand(M, N, device=torch.device('cuda'))
        src = torch.rand(M, N, device=torch.device('cuda'))
        index = torch.tensor(
            numpy.random.randint(0, min(M, N), (M, N)),
            device=torch.device('cuda')
        )
        print(f"Problem size (MxN): {M}x{N}")
        ipython.magic("timeit input_.scatter_(0, index, src); torch.cuda.synchronize()")
        ipython.magic("timeit input_.scatter_(1, index, src); torch.cuda.synchronize()")
```

### TH

```
Problem size (MxN): 8192x4096
11.5 ms ± 4.52 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.21 ms ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Problem size (MxN): 8192x8192
24.1 ms ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.49 ms ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x16384
48.5 ms ± 4.33 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
5.3 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x32768
97.5 ms ± 3.82 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
12.2 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x4096
22.9 ms ± 1.96 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.43 ms ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x8192
48.2 ms ± 3.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
5.03 ms ± 13 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x16384
97.6 ms ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
10.2 ms ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x32768
196 ms ± 8.61 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
20.2 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 32768x4096
45.8 ms ± 4.11 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.85 ms ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x8192
96.4 ms ± 3.98 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
10 ms ± 6.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x16384
195 ms ± 7.16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
20.3 ms ± 161 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 32768x32768
391 ms ± 36.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
40.7 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x4096
91.5 ms ± 5.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.65 ms ± 3.93 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x8192
192 ms ± 9.94 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
20.1 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x16384
390 ms ± 26.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
40.7 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x32768
783 ms ± 33.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
86.9 ms ± 76.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

### ATen

```
Problem size (MxN): 8192x4096
12 ms ± 3.71 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.19 ms ± 236 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Problem size (MxN): 8192x8192
25.1 ms ± 3.91 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.38 ms ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x16384
50.6 ms ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.62 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x32768
102 ms ± 5.16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.26 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x4096
23.9 ms ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.37 ms ± 7.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x8192
50.2 ms ± 3.08 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.83 ms ± 8.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x16384
102 ms ± 3.97 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.79 ms ± 6.48 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x32768
204 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.1 ms ± 32.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x4096
47.8 ms ± 2.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.72 ms ± 8.92 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x8192
100 ms ± 4.53 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.68 ms ± 7.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x16384
203 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.6 ms ± 17.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x32768
408 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
39.4 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x4096
95.6 ms ± 3.77 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.41 ms ± 735 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x8192
201 ms ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.3 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x16384
407 ms ± 16.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
39.2 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x32768
816 ms ± 40.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
78.4 ms ± 45.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

## Case 2: sparse indexing

```python
import torch
from IPython import get_ipython

ipython = get_ipython()
torch.set_num_threads(1)
device = 'cuda'

nrows = 100000
ncols = 100000
dims = [nrows, ncols]

res = torch.randn(dims, device=device)
idx1 = torch.randint(dims[0], (1, dims[1]), device=device).long()
src1 = torch.randn(1, dims[1], device=device)
idx2 = torch.randint(dims[1], (dims[0], 1), device=device).long()
src2 = torch.randn(dims[0], 1, device=device)

ipython.magic("timeit res.scatter_(0, idx1, src1); torch.cuda.synchronize()")
ipython.magic("timeit res.scatter_(1, idx2, src2); torch.cuda.synchronize()")
```

### TH

```
199 µs ± 609 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
43.3 µs ± 95.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

### ATen

```
199 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
119 µs ± 3.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

## Case 3: many-to-one, one-to-many

```python
import torch
from IPython import get_ipython

ipython = get_ipython()
torch.set_num_threads(1)
device = 'cuda'

nfeat = 10000
nrep = 5
a = torch.arange(nfeat, device=device).repeat_interleave(nrep)
batch = 3  # batch can vary 1-200
res = torch.randn(100000, 100000, device=device)

for batch in [100, 500, 1000, 5000, 10000]:
    print("Batch: ", batch)
    c = torch.randint(3, (batch, nfeat * nrep), device=device).float()
    ipython.magic("timeit res.scatter_(1, a.unsqueeze(0).expand(batch, a.size(0)), c); torch.cuda.synchronize()")
    enum_values = [
        list(range(1, 201)),
        list(range(1000, 1020)),
        list(range(2000, 2010)),
        list(range(3000, 3206)),
    ]
    indices = torch.tensor(
        [i for i, values in enumerate(enum_values) for _j in range(len(values))],
        device=device
    )
    c = torch.randint(3, (batch, 4), device=device).float()
    idx = indices.unsqueeze(0).expand(c.size(0), indices.size(0))
    src = c.repeat(1, idx.shape[-1] // c.shape[-1])
    ipython.magic("timeit res.scatter_(1, idx, src); torch.cuda.synchronize()")
    print()
```

### TH

```
Batch:  100
119 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
14.7 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch:  500
534 µs ± 2.24 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
16.4 µs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch:  1000
1.06 ms ± 2.96 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
20.6 µs ± 53.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Batch:  5000
5.28 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
56.3 µs ± 93.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Batch:  10000
10.6 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
101 µs ± 148 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

### ATen

```
Batch:  100
63.9 µs ± 501 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
13.5 µs ± 350 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch:  500
241 µs ± 535 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
14.8 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch:  1000
468 µs ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
16.7 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch:  5000
2.27 ms ± 5.59 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
31.1 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Batch:  10000
4.52 ms ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
54 µs ± 82.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

## Correctness (passed)

```python
import torch
import numpy
from IPython import get_ipython

numpy.random.seed(13)
torch.manual_seed(13)
ipython = get_ipython()

Ms = 1024 * 2
Ns = 1024 * 2
dim = 0
top_power = 5

for pM in range(top_power):
    M = Ms * (2 ** pM)
    for pN in range(top_power):
        N = Ns * (2 ** pN)
        input_ = torch.rand(M, N, device=torch.device('cuda'))
        input_clone_ = input_.clone()
        # src = torch.rand(M, N, device=torch.device('cuda'))
        src = torch.ones(M, N, device=torch.device('cuda'))
        index = torch.tensor(
            numpy.random.randint(0, min(M, N), (M, N)),
            device=torch.device('cuda')
        )
        other_index1 = torch.arange(0, N, device=torch.device('cuda')).repeat(M, 1)
        other_index0 = torch.arange(0, M, device=torch.device('cuda')).repeat(N, 1).t()
        print(f"Problem size (MxN): {M}x{N}")
        input_.scatter_(0, index, src)
        input_clone_.index_put_((index, other_index1), src)
        assert (input_ == input_clone_).all()
        input_.scatter_(1, index, src)
        input_clone_.index_put_((other_index0, index), src)
        assert (input_ == input_clone_).all()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/35697
Differential Revision: D21258380
Pulled By: ngimel
fbshipit-source-id: aebf01474cc9caf0a1dc1041ca6b753e3981df2e
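For reference, the operation being migrated keeps its documented semantics: for a 2-D tensor and `dim=0`, `self.scatter_(0, index, src)` writes `self[index[i][j]][j] = src[i][j]`, which is exactly the rule the correctness section checks against `index_put_`. A minimal pure-Python model of that rule (the helper name `scatter_dim0` and the use of nested lists instead of tensors are illustrative only, not part of the PR):

```python
def scatter_dim0(dest, index, src):
    # dest[index[i][j]][j] = src[i][j] for every (i, j): the index tensor
    # picks the destination *row*, the column j is kept as-is. This mirrors
    # torch.Tensor.scatter_(0, index, src) for 2-D inputs.
    for i in range(len(src)):
        for j in range(len(src[0])):
            dest[index[i][j]][j] = src[i][j]
    return dest

dest = [[0, 0, 0], [0, 0, 0]]
index = [[1, 0, 1], [0, 1, 0]]
src = [[10, 20, 30], [40, 50, 60]]
scatter_dim0(dest, index, src)
# dest is now [[40, 20, 60], [10, 50, 30]]
```

This also shows why the correctness check can compare against `input_clone_.index_put_((index, other_index1), src)`: with `other_index1[i][j] = j`, both formulations address the same elements.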
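In the many-to-one pattern of Case 3, several source elements target the same destination slot; plain `scatter_` does not accumulate them (unlike `scatter_add_`), and on CUDA which colliding write survives is unspecified. A tiny pure-Python illustration of a collision along `dim=1` (helper `scatter_dim1` is hypothetical; a sequential loop is used here, so the last write deterministically wins, which a parallel GPU kernel does not guarantee):

```python
def scatter_dim1(dest, index, src):
    # dest[i][index[i][j]] = src[i][j]: colliding j values overwrite each
    # other; in this sequential model the largest j wins.
    for i in range(len(src)):
        for j in range(len(src[0])):
            dest[i][index[i][j]] = src[i][j]
    return dest

dest = [[0, 0, 0, 0]]
index = [[2, 2, 0]]   # two source elements collide at column 2
src = [[7, 8, 9]]
scatter_dim1(dest, index, src)
# sequentially, dest == [[9, 0, 8, 0]]  (8 overwrote 7 at column 2)
```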