Sparse-sparse matrix multiplication (CPU/CUDA) (#39526)
Summary:
This PR implements matrix multiplication support for 2-d sparse tensors using the COO sparse format.
The current implementation of `torch.sparse.mm` supports the configuration
`torch.sparse.mm(sparse_matrix1, sparse_matrix2.to_dense())`, but this can consume a lot of memory when `sparse_matrix2`'s shape is large.
This implementation extends the `torch.sparse.mm` function to support `torch.sparse.mm(sparse_matrix1, sparse_matrix2)` directly.
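A minimal usage sketch of the two paths (hypothetical shapes and densities, for illustration only):
```python
import torch

# Hypothetical shapes/densities, for illustration only.
a = torch.randn(32, 10000).relu().to_sparse().coalesce()
b = torch.randn(10000, 300).relu().to_sparse().coalesce()

# Before this PR: densify the second operand, materializing a large dense tensor.
out_dense_path = torch.sparse.mm(a, b.to_dense())  # dense result

# With this PR: multiply the two sparse matrices directly.
out = torch.sparse.mm(a, b)  # sparse COO result
```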
Resolves [#20988](https://github.com/pytorch/pytorch/issues/20988) for CPU/CUDA.
- [x] sparse matmul
- [x] CPU/CUDA C++ implementation
- [x] unittests
- [x] update torch.sparse.mm documentation
- [x] autograd support (a small sketch follows this list)
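A small sketch of what the autograd support enables (illustrative shapes):
```python
import torch

a = torch.eye(4).to_sparse().requires_grad_(True)
b = torch.eye(4).to_sparse().requires_grad_(True)

out = torch.sparse.mm(a, b)      # sparse @ sparse, now differentiable
out.to_dense().sum().backward()

print(a.grad)  # sparse gradient, restricted to a's sparsity pattern
```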
The CPU sparse-sparse matmul was implemented using the work "Sparse Matrix Multiplication Package (SMMP)" as a reference. The GPU sparse-sparse matmul is based on cuSPARSE, with separate code paths for CUSPARSE_VERSION >= 11 and for older versions. Both the CPU and CUDA implementations run the sparse-sparse matmul over CSR indices, as it is one of the fastest algorithms for this operation.
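For intuition, the central step is compressing the COO row indices into CSR-style row pointers; a minimal sketch of that conversion (an illustrative helper, not the PR's internal code):
```python
import torch

def coo_rows_to_csr_pointers(row_indices: torch.Tensor, num_rows: int) -> torch.Tensor:
    counts = torch.bincount(row_indices, minlength=num_rows)  # nonzeros per row
    crow = torch.zeros(num_rows + 1, dtype=torch.int64)
    crow[1:] = torch.cumsum(counts, dim=0)  # row i spans crow[i]..crow[i+1]
    return crow

a = torch.eye(4).to_sparse().coalesce()
print(coo_rows_to_csr_pointers(a.indices()[0], a.size(0)))  # tensor([0, 1, 2, 3, 4])
```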
Here are the latest benchmark results (script is here) for `torch.sparse.mm` (CUDA), `torch.sparse.mm` (CPU), and scipy; the matrices hold float32 values:
size | density | sparse.mm(CUDA) | sparse.mm(CPU) | scipy_coo_matmul
-- | -- | -- | -- | --
(32, 10000) | 0.01 | 822.7 | 79.4 | 704.1
(32, 10000) | 0.05 | 1741.1 | 402.6 | 1155.3
(32, 10000) | 0.1 | 2956.8 | 840.8 | 1885.4
(32, 10000) | 0.25 | 6417.7 | 2832.3 | 4665.2
(512, 10000) | 0.01 | 1010.2 | 3941.3 | 26937.7
(512, 10000) | 0.05 | 2216.2 | 26903.8 | 57343.7
(512, 10000) | 0.1 | 4868.4 | 87773.7 | 117477.0
(512, 10000) | 0.25 | 16639.3 | 608105.0 | 624290.4
(1024, 10000) | 0.01 | 1224.8 | 13088.1 | 110379.2
(1024, 10000) | 0.05 | 3897.5 | 94783.9 | 236541.8
(1024, 10000) | 0.1 | 10559.1 | 405312.5 | 525483.4
(1024, 10000) | 0.25 | 57456.3 | 2424337.5 | 2729318.7
A new backward algorithm was implemented using only `sparse @ sparse` and `sparse_mask` operations. Here are some benchmark results:
```
[------------------------- sparse.mm-backward -------------------------]
             size | density | sparse.backward | dense.backward
-----------------------------------------------------------------------
(32, 10000) | 0.01 | 13.5 | 2.4
(32, 10000) | 0.05 | 52.3 | 2.4
(512, 10000) | 0.01 | 1016.8 | 491.5
(512, 10000) | 0.05 | 1604.3 | 492.3
(1024, 10000) | 0.01 | 2384.1 | 1963.7
(1024, 10000) | 0.05 | 3965.8 | 1951.9
```
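For reference, the backward rule reduces to two matmuls followed by `sparse_mask` projections onto the inputs' sparsity patterns. A runnable sketch of the math (it densifies one operand to stay on the public API; the PR's kernels keep everything sparse):
```python
import torch

# Backward of out = a @ b for 2-d sparse COO a, b (sketch only):
#   grad_a = (grad @ b^T) masked to a's sparsity pattern
#   grad_b = (a^T @ grad) masked to b's sparsity pattern
def sparse_mm_backward_sketch(a, b, grad):
    a, b, grad = a.coalesce(), b.coalesce(), grad.coalesce()
    grad_a = torch.sparse.mm(grad, b.to_dense().t()).sparse_mask(a)
    grad_b = torch.sparse.mm(a.t().coalesce(), grad.to_dense()).sparse_mask(b)
    return grad_a, grad_b
```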
I also added new benchmark tests, now using a real dataset from recent studies [1, 2] at different sparsity levels:
```
[---------------------------------- matmul ---------------------------------]
| 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98
1 threads: ------------------------------------------------------------------
(cpu) torch | 5.4 | 5.4 | 5.2 | 5.3 | 5.3 | 5.4
torch.sparse | 122.2 | 51.9 | 27.5 | 11.4 | 4.9 | 1.8
scipy | 150.1 | 87.4 | 69.2 | 56.8 | 38.4 | 17.1
(cuda) torch | 1.3 | 1.1 | 1.1 | 1.1 | 1.1 | 1.1
torch.sparse | 20.0 | 8.4 | 5.1 | 2.5 | 1.5 | 1.1
[----------------------------------- backward -----------------------------------]
| 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98
1 threads: -----------------------------------------------------------------------
(cpu) torch | 17.7 | 17.9 | 17.7 | 17.7 | 17.6 | 17.9
torch.sparse | 672.9 | 432.6 | 327.5 | 230.8 | 176.7 | 116.7
(cuda) torch | 3.8 | 3.6 | 3.5 | 3.5 | 3.6 | 3.5
torch.sparse | 68.8 | 46.2 | 35.6 | 24.2 | 17.8 | 11.9
Times are in milliseconds (ms).
```
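A sketch of how such timings can be collected with `torch.utils.benchmark` (hypothetical shapes and density; the PR's script uses the dataset from [1, 2]):
```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical 512 x 10000 operands at roughly 90% sparsity.
a = torch.randn(512, 10000)
b = torch.randn(10000, 512)
a[a.abs() < 1.65] = 0
b[b.abs() < 1.65] = 0
sa, sb = a.to_sparse().coalesce(), b.to_sparse().coalesce()

t_sparse = benchmark.Timer(stmt="torch.sparse.mm(sa, sb)",
                           globals={"torch": torch, "sa": sa, "sb": sb})
t_dense = benchmark.Timer(stmt="torch.mm(a, b)",
                          globals={"torch": torch, "a": a, "b": b})
print(t_sparse.timeit(20))
print(t_dense.timeit(20))
```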
In summary, the new `sparse @ sparse` backward algorithm is the better choice: its aim is saving memory rather than raw speed, and it still outperformed the other options tested previously.
## **References**
1. Trevor Gale, Matei Zaharia, Cliff Young, Erich Elsen. **Sparse GPU Kernels for Deep Learning.** Proceedings of the International Conference for High Performance Computing, 2020. [https://github.com/google-research/google-research/tree/master/sgk](https://github.com/google-research/google-research/tree/master/sgk)
2. Trevor Gale, Erich Elsen, Sara Hooker. **The State of Sparsity in Deep Neural Networks.** arXiv:1902.09574, 2019. [https://github.com/google-research/google-research/tree/master/state_of_sparsity](https://github.com/google-research/google-research/tree/master/state_of_sparsity)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39526
Reviewed By: mruberry
Differential Revision: D25661239
Pulled By: ngimel
fbshipit-source-id: b515ecd66d25f347d637e159d51aa45fb43b6938