Faster mul(sparse, sparse) with broadcasting in dense dims. (#85336)
This is a combo PR of https://github.com/pytorch/pytorch/pull/84929 and ~https://github.com/pytorch/pytorch/pull/83428~.
Preliminary benchmarks (square matrices of shape (n, n)).
<details>
<summary>Script</summary>
```python
import torch
import math
from IPython import get_ipython
from itertools import product, repeat
import pickle
from torch.utils.benchmark import Timer, Compare
torch.manual_seed(13)
problem_dims = (
# n > nnz
(10000, 100),
(100000, 1000),
(1000000, 10000),
# n < nnz
(10, 100),
(10, 1000),
(10, 10000),
(100, 1000),
(100, 10000),
(1000, 10000),
(1000, 100000),
(1000, 1000000),
#(1000000, 1000000000),
)
name = "PR"
device = "cuda"
results = []
for n, nnz in problem_dims:
def gen_tensor(coalesce=False):
shape = (n, n)
nrows, ncols = shape
rowidx = torch.randint(low=0, high=nrows, size=(nnz,), device=device)
colidx = torch.randint(low=0, high=ncols, size=(nnz,), device=device)
itemidx = torch.vstack((rowidx, colidx))
xvalues = torch.randn(nnz, device=device)
itemidx = torch.hstack((itemidx, itemidx))
xvalues = torch.hstack((xvalues, xvalues))
res = torch.sparse_coo_tensor(itemidx, xvalues, size=shape)
if coalesce:
return res.coalesce()
else:
return res
for x_coalesce, y_coalesce in product(*repeat((True, False), 2)):
x = gen_tensor(x_coalesce)
y = gen_tensor(y_coalesce)
smtp = "x * y"
timer = Timer(smtp,
globals=globals(),
label="coo.mul",
description=f"{name}: mul, device: {device}",
sub_label=f"n={n}, nnz={nnz}, coalesce=({x_coalesce, y_coalesce})",
num_threads=torch.get_num_threads())
results.append(timer.blocked_autorange())
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
with open(f"{name}_{device}_mul.pickle", 'wb') as f:
pickle.dump(results, f)
```
</details>
<details>
<summary>Gather results</summary>
```python
import pickle
from torch.utils.benchmark import Timer, Compare
files = [
"PR",
"master"
]
device = 'cuda'
timers = []
for name in files:
with open("{}_{}_mul.pickle".format(name, device), 'rb') as f:
timers += pickle.load(f)
compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>
<details>
<summary>CUDA</summary>
```
[------------------------------------------------- coo.mul -------------------------------------------------]
| PR: mul, device: cuda | master: mul, device: cuda
24 threads: -------------------------------------------------------------------------------------------------
n=10000, nnz=100, coalesce=((True, True)) | 95 | 91
n=10000, nnz=100, coalesce=((True, False)) | 87 | 242
n=10000, nnz=100, coalesce=((False, True)) | 87 | 226
n=10000, nnz=100, coalesce=((False, False)) | 130 | 371
n=100000, nnz=1000, coalesce=((True, True)) | 100 | 521
n=100000, nnz=1000, coalesce=((True, False)) | 90 | 649
n=100000, nnz=1000, coalesce=((False, True)) | 100 | 659
n=100000, nnz=1000, coalesce=((False, False)) | 200 | 781
n=1000000, nnz=10000, coalesce=((True, True)) | 100 | 4861
n=1000000, nnz=10000, coalesce=((True, False)) | 100 | 5012
n=1000000, nnz=10000, coalesce=((False, True)) | 98 | 5010
n=1000000, nnz=10000, coalesce=((False, False)) | 384 | 5174
n=10, nnz=100, coalesce=((True, True)) | 100 | 79
n=10, nnz=100, coalesce=((True, False)) | 100 | 221
n=10, nnz=100, coalesce=((False, True)) | 100 | 221
n=10, nnz=100, coalesce=((False, False)) | 100 | 350
n=10, nnz=1000, coalesce=((True, True)) | 100 | 100
n=10, nnz=1000, coalesce=((True, False)) | 100 | 240
n=10, nnz=1000, coalesce=((False, True)) | 100 | 254
n=10, nnz=1000, coalesce=((False, False)) | 100 | 392
n=10, nnz=10000, coalesce=((True, True)) | 100 | 110
n=10, nnz=10000, coalesce=((True, False)) | 110 | 286
n=10, nnz=10000, coalesce=((False, True)) | 110 | 286
n=10, nnz=10000, coalesce=((False, False)) | 271 | 455
n=100, nnz=1000, coalesce=((True, True)) | 110 | 851
n=100, nnz=1000, coalesce=((True, False)) | 110 | 1000
n=100, nnz=1000, coalesce=((False, True)) | 110 | 990
n=100, nnz=1000, coalesce=((False, False)) | 140 | 1124
n=100, nnz=10000, coalesce=((True, True)) | 110 | 5137
n=100, nnz=10000, coalesce=((True, False)) | 110 | 5391
n=100, nnz=10000, coalesce=((False, True)) | 100 | 5405
n=100, nnz=10000, coalesce=((False, False)) | 249 | 5539
n=1000, nnz=10000, coalesce=((True, True)) | 100 | 8598
n=1000, nnz=10000, coalesce=((True, False)) | 100 | 8800
n=1000, nnz=10000, coalesce=((False, True)) | 100 | 8782
n=1000, nnz=10000, coalesce=((False, False)) | 255 | 8956
n=1000, nnz=100000, coalesce=((True, True)) | 120 | 84500
n=1000, nnz=100000, coalesce=((True, False)) | 200 | 88560
n=1000, nnz=100000, coalesce=((False, True)) | 160 | 89000
n=1000, nnz=100000, coalesce=((False, False)) | 373 | 89000
n=1000, nnz=1000000, coalesce=((True, True)) | 312 | 606400
n=1000, nnz=1000000, coalesce=((True, False)) | 1340 | 609200
n=1000, nnz=1000000, coalesce=((False, True)) | 1340 | 609100
n=1000, nnz=1000000, coalesce=((False, False)) | 4408 | 611400
Times are in microseconds (us).
```
</details>
<details>
<summary>CPU</summary>
```
[------------------------------------------------ coo.mul ------------------------------------------------]
| PR: mul, device: cpu | master: mul, device: cpu
24 threads: -----------------------------------------------------------------------------------------------
n=10000, nnz=100, coalesce=((True, True)) | 8 | 8
n=10000, nnz=100, coalesce=((True, False)) | 32 | 34
n=10000, nnz=100, coalesce=((False, True)) | 32 | 34
n=10000, nnz=100, coalesce=((False, False)) | 41 | 56
n=100000, nnz=1000, coalesce=((True, True)) | 24 | 24
n=100000, nnz=1000, coalesce=((True, False)) | 90 | 100
n=100000, nnz=1000, coalesce=((False, True)) | 87 | 100
n=100000, nnz=1000, coalesce=((False, False)) | 231 | 255
n=1000000, nnz=10000, coalesce=((True, True)) | 190 | 200
n=1000000, nnz=10000, coalesce=((True, False)) | 908 | 2023
n=1000000, nnz=10000, coalesce=((False, True)) | 800 | 2036
n=1000000, nnz=10000, coalesce=((False, False)) | 3684 | 3989
n=10, nnz=100, coalesce=((True, True)) | 8 | 7
n=10, nnz=100, coalesce=((True, False)) | 34 | 30
n=10, nnz=100, coalesce=((False, True)) | 33 | 30
n=10, nnz=100, coalesce=((False, False)) | 44 | 50
n=10, nnz=1000, coalesce=((True, True)) | 8 | 7
n=10, nnz=1000, coalesce=((True, False)) | 100 | 100
n=10, nnz=1000, coalesce=((False, True)) | 130 | 100
n=10, nnz=1000, coalesce=((False, False)) | 746 | 210
n=10, nnz=10000, coalesce=((True, True)) | 8 | 7
n=10, nnz=10000, coalesce=((True, False)) | 1000 | 1500
n=10, nnz=10000, coalesce=((False, True)) | 1000 | 1510
n=10, nnz=10000, coalesce=((False, False)) | 3063 | 2457
n=100, nnz=1000, coalesce=((True, True)) | 25 | 25
n=100, nnz=1000, coalesce=((True, False)) | 180 | 130
n=100, nnz=1000, coalesce=((False, True)) | 200 | 130
n=100, nnz=1000, coalesce=((False, False)) | 271 | 255
n=100, nnz=10000, coalesce=((True, True)) | 100 | 100
n=100, nnz=10000, coalesce=((True, False)) | 2444 | 2290
n=100, nnz=10000, coalesce=((False, True)) | 2455 | 2357
n=100, nnz=10000, coalesce=((False, False)) | 5316 | 3783
n=1000, nnz=10000, coalesce=((True, True)) | 204 | 211
n=1000, nnz=10000, coalesce=((True, False)) | 2457 | 2480
n=1000, nnz=10000, coalesce=((False, True)) | 2448 | 2539
n=1000, nnz=10000, coalesce=((False, False)) | 3665 | 4801
n=1000, nnz=100000, coalesce=((True, True)) | 2293 | 2374
n=1000, nnz=100000, coalesce=((True, False)) | 9000 | 24620
n=1000, nnz=100000, coalesce=((False, True)) | 8000 | 25080
n=1000, nnz=100000, coalesce=((False, False)) | 26500 | 47650
n=1000, nnz=1000000, coalesce=((True, True)) | 10000 | 13000
n=1000, nnz=1000000, coalesce=((True, False)) | 80000 | 362200
n=1000, nnz=1000000, coalesce=((False, True)) | 78050 | 392600
n=1000, nnz=1000000, coalesce=((False, False)) | 312100 | 766900
Times are in microseconds (us).
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85336
Approved by: https://github.com/cpuhrsch