Faster `index_select` for sparse COO tensors on CPU. (#72710)
Fixes https://github.com/pytorch/pytorch/issues/72212.
This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible.
Benchmark results.
<details>
<summary>Testing script</summary>
```python
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare
torch.manual_seed(13)
#torch.set_num_threads(1)
ipython = get_ipython()
index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
# n > nnz
(10000, 100),
(100000, 1000),
(1000000, 10000),
# n < nnz
(10, 100),
(10, 1000),
(10, 10000),
(100, 1000),
(100, 10000),
(1000, 10000),
(1000, 100000),
(1000, 1000000),
#(1000000, 1000000000),
)
def f(t, d, index):
s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
ss = s.index_select(d, index)
return ss.coo()
name = "PR"
results = []
for (n, nnz), m in product(problem_dims, index_sizes):
for d in (0, 1):
if nnz < n:
shape = (n, n)
else:
shape = (n, nnz // n) if d == 0 else (nnz // n, n)
nrows, ncols = shape
rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
colidx = torch.randint(low=0, high=ncols, size=(nnz,))
itemidx = torch.vstack((rowidx, colidx))
xvalues = torch.randn(nnz)
index = torch.randint(low=0, high=n, size=(m,))
SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
smtp = "SparseX.index_select(d, index)"
timer = Timer(smtp,
globals=globals(),
label="coo.index_select",
description=f"{name}: coo.index_select",
sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
num_threads=torch.get_num_threads())
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open(f"{name}_index_select.pickle", 'wb') as f:
pickle.dump(results, f)
```
</details>
<details>
<summary>Gather results</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"PR",
"torch_sparse",
"master"
]
timers = []
for name in files:
with open("{}_index_select.pickle".format(name), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
<details>
<summary>PR/torch_sparse/master runtime comparison</summary>
```
[----------------------------------- coo.index_select ----------------------------------]
| PR | torch_sparse | master
32 threads: -----------------------------------------------------------------------------
n=10000, nnz=100, index_len=100, dim=0 | 14 | 140 | 10
n=10000, nnz=100, index_len=100, dim=1 | 14 | 200 | 10
n=10000, nnz=100, index_len=1000, dim=0 | 30 | 180 | 38
n=10000, nnz=100, index_len=1000, dim=1 | 34 | 240 | 38
n=10000, nnz=100, index_len=10000, dim=0 | 278 | 460 | 330
n=10000, nnz=100, index_len=10000, dim=1 | 275 | 516 | 330
n=100000, nnz=1000, index_len=100, dim=0 | 16 | 290 | 31
n=100000, nnz=1000, index_len=100, dim=1 | 26 | 390 | 31
n=100000, nnz=1000, index_len=1000, dim=0 | 45 | 405 | 263
n=100000, nnz=1000, index_len=1000, dim=1 | 73 | 500 | 261
n=100000, nnz=1000, index_len=10000, dim=0 | 444 | 783 | 2570
n=100000, nnz=1000, index_len=10000, dim=1 | 470 | 890 | 2590
n=1000000, nnz=10000, index_len=100, dim=0 | 25 | 2400 | 270
n=1000000, nnz=10000, index_len=100, dim=1 | 270 | 4000 | 269
n=1000000, nnz=10000, index_len=1000, dim=0 | 74 | 2600 | 2620
n=1000000, nnz=10000, index_len=1000, dim=1 | 464 | 3600 | 2640
n=1000000, nnz=10000, index_len=10000, dim=0 | 635 | 3300 | 26400
n=1000000, nnz=10000, index_len=10000, dim=1 | 1000 | 3960 | 26400
n=10, nnz=100, index_len=100, dim=0 | 16 | 137 | 16
n=10, nnz=100, index_len=100, dim=1 | 16 | 220 | 16
n=10, nnz=100, index_len=1000, dim=0 | 63 | 238 | 81
n=10, nnz=100, index_len=1000, dim=1 | 60 | 698 | 78
n=10, nnz=100, index_len=10000, dim=0 | 480 | 940 | 862
n=10, nnz=100, index_len=10000, dim=1 | 330 | 4930 | 1070
n=10, nnz=1000, index_len=100, dim=0 | 60 | 200 | 73
n=10, nnz=1000, index_len=100, dim=1 | 56 | 683 | 70
n=10, nnz=1000, index_len=1000, dim=0 | 480 | 530 | 1050
n=10, nnz=1000, index_len=1000, dim=1 | 330 | 4550 | 1368
n=10, nnz=1000, index_len=10000, dim=0 | 3100 | 2900 | 9300
n=10, nnz=1000, index_len=10000, dim=1 | 3400 | 46000 | 9100
n=10, nnz=10000, index_len=100, dim=0 | 400 | 453 | 857
n=10, nnz=10000, index_len=100, dim=1 | 400 | 4070 | 1730
n=10, nnz=10000, index_len=1000, dim=0 | 2840 | 2600 | 13900
n=10, nnz=10000, index_len=1000, dim=1 | 3700 | 40600 | 16000
n=10, nnz=10000, index_len=10000, dim=0 | 83200 | 67400 | 160000
n=10, nnz=10000, index_len=10000, dim=1 | 68000 | 528000 | 190000
n=100, nnz=1000, index_len=100, dim=0 | 46 | 148 | 31
n=100, nnz=1000, index_len=100, dim=1 | 45 | 242 | 37
n=100, nnz=1000, index_len=1000, dim=0 | 68 | 248 | 240
n=100, nnz=1000, index_len=1000, dim=1 | 66 | 755 | 290
n=100, nnz=1000, index_len=10000, dim=0 | 370 | 802 | 2250
n=100, nnz=1000, index_len=10000, dim=1 | 372 | 5430 | 2770
n=100, nnz=10000, index_len=100, dim=0 | 82 | 210 | 224
n=100, nnz=10000, index_len=100, dim=1 | 74 | 986 | 270
n=100, nnz=10000, index_len=1000, dim=0 | 350 | 618 | 2600
n=100, nnz=10000, index_len=1000, dim=1 | 370 | 4660 | 4560
n=100, nnz=10000, index_len=10000, dim=0 | 3000 | 3400 | 41680
n=100, nnz=10000, index_len=10000, dim=1 | 5000 | 47500 | 30400
n=1000, nnz=10000, index_len=100, dim=0 | 71 | 160 | 185
n=1000, nnz=10000, index_len=100, dim=1 | 64 | 516 | 190
n=1000, nnz=10000, index_len=1000, dim=0 | 100 | 249 | 1740
n=1000, nnz=10000, index_len=1000, dim=1 | 98 | 1030 | 1770
n=1000, nnz=10000, index_len=10000, dim=0 | 600 | 808 | 18300
n=1000, nnz=10000, index_len=10000, dim=1 | 663 | 5300 | 18500
n=1000, nnz=100000, index_len=100, dim=0 | 160 | 258 | 1890
n=1000, nnz=100000, index_len=100, dim=1 | 200 | 3620 | 2050
n=1000, nnz=100000, index_len=1000, dim=0 | 500 | 580 | 18700
n=1000, nnz=100000, index_len=1000, dim=1 | 640 | 7550 | 30000
n=1000, nnz=100000, index_len=10000, dim=0 | 3400 | 3260 | 186000
n=1000, nnz=100000, index_len=10000, dim=1 | 3600 | 49600 | 194000
n=1000, nnz=1000000, index_len=100, dim=0 | 517 | 957 | 18700
n=1000, nnz=1000000, index_len=100, dim=1 | 680 | 39600 | 37600
n=1000, nnz=1000000, index_len=1000, dim=0 | 3600 | 4500 | 186000
n=1000, nnz=1000000, index_len=1000, dim=1 | 5800 | 76400 | 190000
n=1000, nnz=1000000, index_len=10000, dim=0 | 50000 | 67900 | 1800000
n=1000, nnz=1000000, index_len=10000, dim=1 | 45000 | 570000 | 1900000
Times are in microseconds (us).
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72710
Approved by: https://github.com/pearu, https://github.com/cpuhrsch