Migrate `scatter` and `scatter_` from TH to ATen (CUDA) (#35697)
Summary:
Fixes [24621](https://github.com/pytorch/pytorch/issues/24621).
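For reference, `scatter_(dim, index, src)` writes each element of `src` into `self` at the position given by `index` along `dim`. A minimal pure-Python sketch of the 2-D, `dim=0` semantics (an illustrative reference model, not the CUDA kernel being migrated):

```python
def scatter_dim0(self_t, index, src):
    # Reference semantics for 2-D scatter along dim 0:
    #   self_t[index[i][j]][j] = src[i][j]
    for i in range(len(index)):
        for j in range(len(index[0])):
            self_t[index[i][j]][j] = src[i][j]
    return self_t

dest = [[0, 0, 0] for _ in range(2)]
result = scatter_dim0(dest, [[1, 0, 1], [0, 1, 0]], [[5, 6, 7], [8, 9, 10]])
# result == [[8, 6, 10], [5, 9, 7]]
```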
Some preliminary results:
## Case 1: dense indexing
```python
import torch
import numpy
from IPython import get_ipython
numpy.random.seed(13)
torch.manual_seed(13)
ipython = get_ipython()
Ms = 1024 * 8
Ns = 1024 * 4
dim = 0
top_power = 4
for pM in range(top_power):
    M = Ms * (2 ** pM)
    for pN in range(top_power):
        N = Ns * (2 ** pN)
        input_ = torch.rand(M, N, device=torch.device('cuda'))
        src = torch.rand(M, N, device=torch.device('cuda'))
        index = torch.tensor(numpy.random.randint(0, min(M, N), (M, N)), device=torch.device('cuda'))
        print(f"Problem size (MxN): {M}x{N}")
        ipython.magic("timeit input_.scatter_(0, index, src); torch.cuda.synchronize()")
        ipython.magic("timeit input_.scatter_(1, index, src); torch.cuda.synchronize()")
```
### TH
```
Problem size (MxN): 8192x4096
11.5 ms ± 4.52 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.21 ms ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Problem size (MxN): 8192x8192
24.1 ms ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.49 ms ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x16384
48.5 ms ± 4.33 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
5.3 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x32768
97.5 ms ± 3.82 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
12.2 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x4096
22.9 ms ± 1.96 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.43 ms ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x8192
48.2 ms ± 3.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
5.03 ms ± 13 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x16384
97.6 ms ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
10.2 ms ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x32768
196 ms ± 8.61 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
20.2 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 32768x4096
45.8 ms ± 4.11 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.85 ms ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x8192
96.4 ms ± 3.98 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
10 ms ± 6.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x16384
195 ms ± 7.16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
20.3 ms ± 161 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 32768x32768
391 ms ± 36.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
40.7 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x4096
91.5 ms ± 5.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.65 ms ± 3.93 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x8192
192 ms ± 9.94 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
20.1 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x16384
390 ms ± 26.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
40.7 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x32768
783 ms ± 33.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
86.9 ms ± 76.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
### ATen
```
Problem size (MxN): 8192x4096
12 ms ± 3.71 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.19 ms ± 236 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Problem size (MxN): 8192x8192
25.1 ms ± 3.91 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.38 ms ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x16384
50.6 ms ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.62 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 8192x32768
102 ms ± 5.16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.26 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x4096
23.9 ms ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.37 ms ± 7.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x8192
50.2 ms ± 3.08 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.83 ms ± 8.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x16384
102 ms ± 3.97 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.79 ms ± 6.48 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 16384x32768
204 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.1 ms ± 32.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x4096
47.8 ms ± 2.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.72 ms ± 8.92 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x8192
100 ms ± 4.53 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.68 ms ± 7.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x16384
203 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.6 ms ± 17.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 32768x32768
408 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
39.4 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x4096
95.6 ms ± 3.77 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.41 ms ± 735 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x8192
201 ms ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.3 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Problem size (MxN): 65536x16384
407 ms ± 16.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
39.2 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Problem size (MxN): 65536x32768
816 ms ± 40.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
78.4 ms ± 45.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
## Case 2: sparse indexing
```python
import torch
from IPython import get_ipython
ipython = get_ipython()
torch.set_num_threads(1)
device = 'cuda'
nrows = 100000
ncols = 100000
dims = [nrows, ncols]
res = torch.randn(dims, device=device)
idx1 = torch.randint(dims[0], (1, dims[1]), device=device).long()
src1 = torch.randn(1, dims[1], device=device)
idx2 = torch.randint(dims[1], (dims[0], 1), device=device).long()
src2 = torch.randn(dims[0], 1, device=device)
ipython.magic("timeit res.scatter_(0, idx1, src1); torch.cuda.synchronize()")
ipython.magic("timeit res.scatter_(1, idx2, src2); torch.cuda.synchronize()")
```
### TH
```
199 µs ± 609 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
43.3 µs ± 95.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
### ATen
```
199 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
119 µs ± 3.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
## Case 3: many-to-one, one-to-many
```python
import torch
from IPython import get_ipython
ipython = get_ipython()
torch.set_num_threads(1)
device = 'cuda'
nfeat = 10000
nrep = 5
a = torch.arange(nfeat, device=device).repeat_interleave(nrep)
batch = 3  # batch can vary 1-200
res = torch.randn(100000, 100000, device=device)
for batch in [100, 500, 1000, 5000, 10000]:
    print("Batch: ", batch)
    c = torch.randint(3, (batch, nfeat * nrep), device=device).float()
    ipython.magic("timeit res.scatter_(1, a.unsqueeze(0).expand(batch, a.size(0)), c); torch.cuda.synchronize()")
    enum_values = [
        list(range(1, 201)),
        list(range(1000, 1020)),
        list(range(2000, 2010)),
        list(range(3000, 3206)),
    ]
    indices = torch.tensor([i for i, values in enumerate(enum_values) for _j in range(len(values))], device=device)
    c = torch.randint(3, (batch, 4), device=device).float()
    idx = indices.unsqueeze(0).expand(c.size(0), indices.size(0))
    src = c.repeat(1, idx.shape[-1] // c.shape[-1])
    ipython.magic("timeit res.scatter_(1, idx, src); torch.cuda.synchronize()")
    print()
```
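In the many-to-one pattern above, several source elements target the same destination element. `scatter_` keeps exactly one of the conflicting writes, and on CUDA which one wins is not guaranteed. A small pure-Python illustration of the duplicate-index case, assuming last-write-wins purely for clarity (an assumption for the sketch, not a CUDA guarantee):

```python
def scatter_dim1_row(row, index, src):
    # Reference semantics for a single row: row[index[j]] = src[j].
    # With duplicate indices the later write overwrites the earlier one
    # here; the real CUDA kernel makes no ordering guarantee.
    out = list(row)
    for j, k in enumerate(index):
        out[k] = src[j]
    return out

# Indices 0 and 2 each appear twice: the "many-to-one" case.
row_result = scatter_dim1_row([0, 0, 0, 0], [0, 2, 0, 2], [1, 2, 3, 4])
# row_result == [3, 0, 4, 0] under last-write-wins
```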
### TH
```
Batch: 100
119 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
14.7 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch: 500
534 µs ± 2.24 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
16.4 µs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch: 1000
1.06 ms ± 2.96 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
20.6 µs ± 53.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Batch: 5000
5.28 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
56.3 µs ± 93.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Batch: 10000
10.6 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
101 µs ± 148 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
### ATen
```
Batch: 100
63.9 µs ± 501 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
13.5 µs ± 350 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch: 500
241 µs ± 535 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
14.8 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch: 1000
468 µs ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
16.7 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Batch: 5000
2.27 ms ± 5.59 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
31.1 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Batch: 10000
4.52 ms ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
54 µs ± 82.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
## Correctness (passed)
```python
import torch
import numpy
from IPython import get_ipython
numpy.random.seed(13)
torch.manual_seed(13)
ipython = get_ipython()
Ms = 1024 * 2
Ns = 1024 * 2
dim = 0
top_power = 5
for pM in range(top_power):
    M = Ms * (2 ** pM)
    for pN in range(top_power):
        N = Ns * (2 ** pN)
        input_ = torch.rand(M, N, device=torch.device('cuda'))
        input_clone_ = input_.clone()
        # src = torch.rand(M, N, device=torch.device('cuda'))
        src = torch.ones(M, N, device=torch.device('cuda'))
        index = torch.tensor(numpy.random.randint(0, min(M, N), (M, N)), device=torch.device('cuda'))
        other_index1 = torch.arange(0, N, device=torch.device('cuda')).repeat(M, 1)
        other_index0 = torch.arange(0, M, device=torch.device('cuda')).repeat(N, 1).t()
        print(f"Problem size (MxN): {M}x{N}")
        input_.scatter_(0, index, src)
        input_clone_.index_put_((index, other_index1), src)
        assert (input_ == input_clone_).all()
        input_.scatter_(1, index, src)
        input_clone_.index_put_((other_index0, index), src)
        assert (input_ == input_clone_).all()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35697
Differential Revision: D21258380
Pulled By: ngimel
fbshipit-source-id: aebf01474cc9caf0a1dc1041ca6b753e3981df2e