Sparse CSR CPU: cuSolverSP backend for `linalg.solve` (#71399)
Summary:
This PR introduces the `cuSolverSP` backend for `linalg.solve` with sparse CSR input matrices. The motivation comes from the issue: https://github.com/pytorch/pytorch/issues/69538.
`cuSolver` provides [`cusolverSp<t>csrlsvluHost`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu) API, a few things to note:
1. As mentioned in the documentation: `only CPU (Host) path is provided.` From the profiling, there doesn't seem to be any GPU kernel launch for optimization, please see the profiling below.
2. Since only `host` path is provided, the CPU path uses `csrlsvluHost` (but requires PyTorch to be installed/built with CUDA support).
3. The documentation mentions reordering helps optimize stuff, but it isn't clear how it affects the performance. There are options for reordering, so we stick to `reorder = 0` as the default choice.
`cuSolver` has [`csrlsvqr`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr) function which provides a `device` path to solve the linear system. This function is used for the CUDA path in this PR.
**Gist:**
For CPU Path: we call [`csrlsvluHost` function of cuSolver](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu).
For CUDA Path: we call [`csrlsvqr` function of cuSolver](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr).
**Profiling:** (On sparse input tensor of size 1000 x 1000, with a vector of shape length 1000), for `csrlsvlu` function (to show no GPU optimization)
```cpp
==3999651== Profiling result:
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 100.00% 2.1440us 1 2.1440us 2.1440us 2.1440us [CUDA memcpy HtoD]
API calls: 99.72% 1.07199s 9 119.11ms 500ns 1.07164s cudaFree
0.11% 1.2182ms 398 3.0600us 140ns 137.94us cuDeviceGetAttribute
0.06% 674.45us 4 168.61us 165.50us 173.64us cuDeviceTotalMem
0.03% 357.07us 4 89.268us 2.7800us 201.89us cudaMalloc
0.03% 309.29us 1 309.29us 309.29us 309.29us cudaGetDeviceProperties
0.01% 160.47us 332 483ns 350ns 3.3300us cudaFuncSetAttribute
0.01% 115.12us 4 28.780us 26.290us 33.410us cuDeviceGetName
0.00% 28.591us 5 5.7180us 440ns 16.921us cudaGetDevice
0.00% 22.061us 4 5.5150us 871ns 18.690us cudaDeviceSynchronize
0.00% 20.370us 18 1.1310us 410ns 6.9900us cudaEventDestroy
0.00% 16.390us 1 16.390us 16.390us 16.390us cudaMemcpy
0.00% 11.540us 2 5.7700us 1.4900us 10.050us cuDeviceGetPCIBusId
0.00% 10.510us 18 583ns 430ns 1.6200us cudaEventCreateWithFlags
0.00% 7.9100us 21 376ns 290ns 700ns cudaDeviceGetAttribute
0.00% 1.4300us 6 238ns 150ns 590ns cuDeviceGet
0.00% 1.2200us 4 305ns 190ns 500ns cuDeviceGetCount
0.00% 900ns 1 900ns 900ns 900ns cuInit
0.00% 860ns 4 215ns 180ns 260ns cuDeviceGetUuid
0.00% 240ns 1 240ns 240ns 240ns cuDriverGetVersion
0.00% 230ns 1 230ns 230ns 230ns cudaGetDeviceCount
```
Script:
```python
import torch
def solve(x, other, out):
torch.linalg.solve(x, other, out=out)
if __name__ == "__main__":
dense_inp = torch.randn((1000, 1000), dtype=torch.float64)
# Set 50% of the values to 0 randomly
dense_inp = torch.nn.functional.dropout(dense_inp, p=0.5)
sparse_inp = dense_inp.to_sparse_csr()
other = torch.randint(100, (1000,), dtype=torch.float64)
out = torch.randint(1, (1000,), dtype=torch.float64)
solve(sparse_inp, other, out)
```
The following error is raised when the function is used on a CPU device with PyTorch built/installed without CUDA support:
* When built without CUDA support:
```python
/home/krshrimali/pytorch/torch/autograd/profiler.py:151: UserWarning: CUDA is not available, disabling CUDA profiling
warn("CUDA is not available, disabling CUDA profiling")
Traceback (most recent call last):
File "/home/krshrimali/pytorch/test_sp.py", line 17, in <module>
solve(x, other, out)
File "/home/krshrimali/pytorch/test_sp.py", line 5, in solve
torch.linalg.solve(x, other, out=out)
RuntimeError: PyTorch was not built with CUDA support. Please use PyTorch built CUDA support
```
**Performance Comparison** (vs SciPy's [`scipy.sparse.linalg.spsolve`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html):
Time taken by `scipy.sparse.linalg.spsolve` : 0.595 seconds
On CPU: Time taken by `torch.linalg.solve` : 4.565 seconds
On CUDA: Time taken by `torch.linalg.solve`: 1.838 seconds
The inputs are of dimensions: (17281, 17281) and (17281, 1), and were taken from https://math.nist.gov/MatrixMarket/extreme.html.
Thanks to IvanYashchuk for helping me with the PR, and guiding me through it.
cc: IvanYashchuk pearu nikitaved cpuhrsch
cc nikitaved pearu cpuhrsch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71399
Reviewed By: VitalyFedyunin
Differential Revision: D33767740
Pulled By: cpuhrsch
fbshipit-source-id: a945f065210cd719096eb8d7cdbf8e8937c2fce9
(cherry picked from commit f4f35c17da414e1ca6c6d91402933521857aa1ea)