LU Solve using cublas and cusolver (#59148)
Summary:
This PR introduces cuSOLVER and cuBLAS backends for the `lu_solve` routine, partially addressing https://github.com/pytorch/pytorch/issues/47953.
Because unconditionally dispatching to cuSOLVER causes performance regressions relative to MAGMA (https://github.com/pytorch/pytorch/issues/56590), we use heuristics based on the batch and matrix sizes to decide whether to call cuSOLVER, cuBLAS, or MAGMA. The 64-bit cuSOLVER API is not introduced in this PR because there are several open problems with the cuSOLVER-based LU factorization (https://github.com/pytorch/pytorch/pull/59148).
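For context, `lu_solve` applies a factorization produced by `torch.lu` to solve `A x = b`. A minimal CPU example of the routine whose CUDA dispatch this PR changes (the backend selection here only affects CUDA tensors):

```python
import torch

# Solve A x = b by factorizing once and reusing the factors.
torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64)
b = torch.randn(3, 1, dtype=torch.float64)

LU, pivots = torch.lu(A)           # LU factorization with partial pivoting
x = torch.lu_solve(b, LU, pivots)  # triangular solves using the factors

assert torch.allclose(A @ x, b)    # x solves the original system
```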
The following are performance benchmarks using various configurations:
<details>
```
[--------------------------------------------------------- LU solve CUDA torch.float64 ----------------------------------------------------------]
| lu_solve CUSOLVER | lu_solve MAGMA | lu_solve CUBLAS | lu_solve cuSOLVER/MAGMA | lu_solve TEST ALL
1 threads: ---------------------------------------------------------------------------------------------------------------------------------------
torch.Size([1, 1, 1]) | 703.4 | 489.8 | 511.8 | 710.1 | 487.1
torch.Size([2, 1, 1]) | 738.9 | 504.1 | 513.0 | 958.2 | 494.4
torch.Size([4, 1, 1]) | 790.7 | 514.7 | 506.8 | 983.9 | 540.2
torch.Size([8, 1, 1]) | 865.3 | 496.4 | 514.7 | 975.2 | 520.0
torch.Size([16, 1, 1]) | 955.5 | 483.9 | 508.3 | 937.6 | 526.5
torch.Size([32, 1, 1]) | 1167.7 | 495.2 | 511.2 | 934.0 | 528.7
torch.Size([64, 1, 1]) | 1730.0 | 492.1 | 537.8 | 936.4 | 533.2
torch.Size([128, 1, 1]) | 2748.4 | 499.7 | 526.5 | 982.9 | 540.8
torch.Size([1, 2, 2]) | 724.6 | 498.2 | 541.7 | 715.0 | 504.7
torch.Size([2, 2, 2]) | 737.0 | 514.3 | 527.6 | 934.5 | 524.5
torch.Size([4, 2, 2]) | 750.5 | 524.1 | 537.4 | 935.5 | 543.0
torch.Size([8, 2, 2]) | 844.8 | 513.7 | 538.9 | 953.3 | 534.4
torch.Size([16, 2, 2]) | 1013.1 | 521.9 | 530.0 | 932.2 | 537.9
torch.Size([32, 2, 2]) | 1335.8 | 515.1 | 544.4 | 939.9 | 559.5
torch.Size([64, 2, 2]) | 1819.6 | 511.8 | 534.1 | 973.9 | 540.0
torch.Size([128, 2, 2]) | 3018.7 | 526.3 | 546.1 | 979.3 | 543.5
torch.Size([1, 8, 8]) | 732.5 | 524.9 | 532.9 | 762.4 | 516.8
torch.Size([2, 8, 8]) | 771.2 | 514.9 | 538.7 | 1007.5 | 531.1
torch.Size([4, 8, 8]) | 811.3 | 507.7 | 534.6 | 1002.2 | 548.5
torch.Size([8, 8, 8]) | 866.6 | 530.0 | 532.0 | 1016.1 | 562.9
torch.Size([16, 8, 8]) | 991.8 | 533.6 | 548.0 | 1022.6 | 548.5
torch.Size([32, 8, 8]) | 1271.7 | 541.2 | 534.7 | 1013.8 | 545.6
torch.Size([64, 8, 8]) | 1817.2 | 530.2 | 520.6 | 1008.7 | 566.3
torch.Size([128, 8, 8]) | 2678.7 | 531.6 | 552.2 | 1006.2 | 555.0
torch.Size([1, 16, 16]) | 738.2 | 546.1 | 536.6 | 775.6 | 540.1
torch.Size([2, 16, 16]) | 782.6 | 543.5 | 539.6 | 1010.9 | 541.1
torch.Size([4, 16, 16]) | 815.2 | 546.1 | 560.9 | 1012.5 | 553.1
torch.Size([8, 16, 16]) | 877.7 | 543.0 | 547.9 | 1012.8 | 551.5
torch.Size([16, 16, 16]) | 1008.7 | 549.2 | 562.7 | 1016.6 | 546.8
torch.Size([32, 16, 16]) | 1291.9 | 540.8 | 560.3 | 1055.8 | 539.3
torch.Size([64, 16, 16]) | 1846.3 | 553.5 | 556.0 | 1010.8 | 551.9
torch.Size([128, 16, 16]) | 2953.8 | 562.7 | 547.5 | 1026.2 | 555.8
torch.Size([1, 32, 32]) | 789.1 | 590.6 | 590.9 | 790.5 | 579.0
torch.Size([2, 32, 32]) | 806.9 | 596.6 | 600.2 | 1085.6 | 573.8
torch.Size([4, 32, 32]) | 852.0 | 597.9 | 588.2 | 1098.9 | 574.7
torch.Size([8, 32, 32]) | 914.2 | 597.8 | 591.4 | 1090.3 | 585.7
torch.Size([16, 32, 32]) | 1063.0 | 604.6 | 597.3 | 1094.0 | 580.5
torch.Size([32, 32, 32]) | 1302.0 | 602.0 | 598.9 | 1090.3 | 583.6
torch.Size([64, 32, 32]) | 1861.7 | 601.1 | 599.8 | 1113.4 | 588.6
torch.Size([128, 32, 32]) | 3251.0 | 619.6 | 595.3 | 1106.8 | 608.9
torch.Size([1, 64, 64]) | 978.6 | 842.7 | 778.6 | 1071.4 | 825.8
torch.Size([2, 64, 64]) | 1072.3 | 845.7 | 785.4 | 1400.6 | 829.0
torch.Size([4, 64, 64]) | 1051.9 | 842.9 | 796.1 | 1352.2 | 788.2
torch.Size([8, 64, 64]) | 1090.3 | 834.1 | 805.2 | 1382.6 | 804.7
torch.Size([16, 64, 64]) | 1206.9 | 835.7 | 802.2 | 1365.6 | 801.2
torch.Size([32, 64, 64]) | 1671.2 | 846.5 | 794.5 | 1345.1 | 814.2
torch.Size([64, 64, 64]) | 2759.3 | 848.5 | 795.4 | 1409.7 | 832.9
torch.Size([128, 64, 64]) | 4928.6 | 877.4 | 848.3 | 1439.0 | 883.9
torch.Size([1, 128, 128]) | 1315.6 | 1158.4 | 1130.0 | 1301.3 | 1177.1
torch.Size([2, 128, 128]) | 1334.7 | 1198.2 | 1186.6 | 1703.9 | 1209.5
torch.Size([4, 128, 128]) | 1374.6 | 1200.7 | 1266.2 | 1640.6 | 1272.3
torch.Size([8, 128, 128]) | 1453.6 | 1215.9 | 1287.3 | 1669.1 | 1288.7
torch.Size([16, 128, 128]) | 1882.1 | 1244.9 | 1337.6 | 1698.8 | 1347.1
torch.Size([32, 128, 128]) | 2789.0 | 1284.5 | 1398.6 | 1747.6 | 1396.3
torch.Size([64, 128, 128]) | 4763.0 | 1425.2 | 1581.7 | 1921.0 | 1584.1
torch.Size([128, 128, 128]) | 8835.9 | 1808.9 | 1968.7 | 2197.6 | 1961.8
torch.Size([1, 512, 512]) | 4369.9 | 4577.6 | 4804.0 | 4331.4 | 4599.0
torch.Size([2, 512, 512]) | 4635.9 | 4850.1 | 5159.1 | 5315.4 | 4845.5
torch.Size([4, 512, 512]) | 5367.5 | 5261.6 | 6134.7 | 5807.8 | 5345.2
torch.Size([8, 512, 512]) | 7025.2 | 6184.5 | 7065.6 | 6711.6 | 6303.9
torch.Size([16, 512, 512]) | 10221.3 | 7849.7 | 8820.1 | 8323.6 | 7992.1
torch.Size([32, 512, 512]) | 16574.8 | 11208.4 | 12284.3 | 11704.7 | 11394.4
torch.Size([64, 512, 512]) | 29500.1 | 18043.1 | 19249.3 | 18744.0 | 18242.1
torch.Size([128, 512, 512]) | 56783.3 | 33903.9 | 34713.5 | 33893.8 | 34041.8
torch.Size([1, 1024, 1024]) | 14864.5 | 15714.6 | 16128.1 | 14726.7 | 14992.6
torch.Size([2, 1024, 1024]) | 17891.0 | 18553.3 | 19111.6 | 19271.5 | 19283.0
torch.Size([4, 1024, 1024]) | 22143.4 | 21909.2 | 23667.1 | 22698.9 | 22713.8
torch.Size([8, 1024, 1024]) | 30621.1 | 28669.9 | 30822.9 | 29725.0 | 29760.8
torch.Size([16, 1024, 1024]) | 47045.9 | 41900.0 | 44353.8 | 43215.6 | 43237.5
torch.Size([32, 1024, 1024]) | 79245.5 | 68316.9 | 70959.0 | 69506.4 | 69876.7
torch.Size([64, 1024, 1024]) | 147973.9 | 121120.6 | 124601.1 | 122084.4 | 122578.7
torch.Size([128, 1024, 1024]) | 295586.2 | 232871.8 | 237421.8 | 233765.3 | 234704.6
Times are in microseconds (us).
```
</details>
Here are the details of how the benchmarks were run:
* CUSOLVER - Only call `cusolver` for all problem sizes.
* MAGMA - Only call `magma` for all problem sizes (this is the current master branch).
* CUBLAS - Only call `cublas` for all problem sizes.
* cuSOLVER / MAGMA - Use cusolver for `batch_size == 1` and magma for all others.
* TEST ALL - Employ heuristics to switch between cuBLAS, cuSOLVER, and MAGMA. This yields the best overall results (this PR).
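The TEST ALL dispatch above can be sketched as a simple size-based selector. The thresholds below are illustrative placeholders chosen to match the broad pattern in the benchmark (cuSOLVER winning for single large matrices, batched backends elsewhere), not the exact cutoffs implemented in the PR:

```python
# Hypothetical sketch of the backend heuristic; thresholds are illustrative.
def choose_lu_solve_backend(batch_size: int, n: int) -> str:
    """Pick an LU-solve backend from batch count and matrix dimension n."""
    if batch_size == 1 and n >= 512:
        return "cusolver"  # single large matrix: cuSOLVER is fastest
    if n <= 8:
        return "cublas"    # many tiny matrices: batched cuBLAS
    return "magma"         # everything else: MAGMA
```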
Script for reproducing the results:
<details>
```python
import itertools
import pickle
import sys

import torch
from torch.utils.benchmark import Timer

shapes = [1, 2, 8, 16, 32, 64, 128, 512, 1024]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,)]
results = []
num_threads = 1
dtype = torch.float64
repeats = 2

def lu_factorize_solve(mat, b):
    lu_data = torch.lu(mat)
    x = torch.lu_solve(b, *lu_data)

for shape, batch in itertools.product(shapes, batches):
    mat = torch.randn(*batch, shape, shape, dtype=dtype, device='cuda')
    b = torch.randn(*batch, shape, 1, dtype=dtype, device='cuda')
    # The description label is set per run; the script was rerun with each
    # backend forced to produce the columns in the table above.
    tasks = [("lu_factorize_solve(mat, b)", "lu_solve CUSOLVER")]
    print("shape: ", shape, " batch: ", batch)
    timers = [Timer(stmt=stmt, num_threads=num_threads, label=f"LU solve CUDA {dtype}",
                    sub_label=f"{mat.shape}", description=label, globals=globals())
              for stmt, label in tasks]
    for i, timer in enumerate(timers * repeats):
        results.append(
            pickle.dumps(timer.blocked_autorange())
        )
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()

with open("cusolver_lu_solve.pickle", "wb") as f:
    pickle.dump(results, f)
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59148
Reviewed By: H-Huang
Differential Revision: D29160609
Pulled By: mruberry
fbshipit-source-id: 7280f25db1e66aa650ea15608a6dc5d688fb4db2