1dee99c9 - LU Solve using cublas and cusolver (#59148)

Summary: This PR introduces cuSOLVER and cuBLAS backends for the `lu_solve` routine. It solves part of https://github.com/pytorch/pytorch/issues/47953. Since using cuSOLVER together with MAGMA introduces performance regressions in MAGMA (https://github.com/pytorch/pytorch/issues/56590), we use heuristics based on the batch and matrix sizes to determine when to call cuSOLVER, cuBLAS, or MAGMA. The 64-bit cuSOLVER API is not introduced in this PR since there are several problems with the LU factorization using cuSOLVER (https://github.com/pytorch/pytorch/pull/59148).

The following are performance benchmarks using various configurations:

<details>

```
[--------------------------------------------------------- LU solve CUDA torch.float64 ----------------------------------------------------------]
                               | lu_solve CUSOLVER | lu_solve MAGMA | lu_solve CUBLAS | lu_solve cuSOLVER/MAGMA | lu_solve TEST ALL
1 threads: ---------------------------------------------------------------------------------------------------------------------------------------
torch.Size([1, 1, 1])          |   703.4 |   489.8 |   511.8 |   710.1 |   487.1
torch.Size([2, 1, 1])          |   738.9 |   504.1 |   513.0 |   958.2 |   494.4
torch.Size([4, 1, 1])          |   790.7 |   514.7 |   506.8 |   983.9 |   540.2
torch.Size([8, 1, 1])          |   865.3 |   496.4 |   514.7 |   975.2 |   520.0
torch.Size([16, 1, 1])         |   955.5 |   483.9 |   508.3 |   937.6 |   526.5
torch.Size([32, 1, 1])         |  1167.7 |   495.2 |   511.2 |   934.0 |   528.7
torch.Size([64, 1, 1])         |  1730.0 |   492.1 |   537.8 |   936.4 |   533.2
torch.Size([128, 1, 1])        |  2748.4 |   499.7 |   526.5 |   982.9 |   540.8
torch.Size([1, 2, 2])          |   724.6 |   498.2 |   541.7 |   715.0 |   504.7
torch.Size([2, 2, 2])          |   737.0 |   514.3 |   527.6 |   934.5 |   524.5
torch.Size([4, 2, 2])          |   750.5 |   524.1 |   537.4 |   935.5 |   543.0
torch.Size([8, 2, 2])          |   844.8 |   513.7 |   538.9 |   953.3 |   534.4
torch.Size([16, 2, 2])         |  1013.1 |   521.9 |   530.0 |   932.2 |   537.9
torch.Size([32, 2, 2])         |  1335.8 |   515.1 |   544.4 |   939.9 |   559.5
torch.Size([64, 2, 2])         |  1819.6 |   511.8 |   534.1 |   973.9 |   540.0
torch.Size([128, 2, 2])        |  3018.7 |   526.3 |   546.1 |   979.3 |   543.5
torch.Size([1, 8, 8])          |   732.5 |   524.9 |   532.9 |   762.4 |   516.8
torch.Size([2, 8, 8])          |   771.2 |   514.9 |   538.7 |  1007.5 |   531.1
torch.Size([4, 8, 8])          |   811.3 |   507.7 |   534.6 |  1002.2 |   548.5
torch.Size([8, 8, 8])          |   866.6 |   530.0 |   532.0 |  1016.1 |   562.9
torch.Size([16, 8, 8])         |   991.8 |   533.6 |   548.0 |  1022.6 |   548.5
torch.Size([32, 8, 8])         |  1271.7 |   541.2 |   534.7 |  1013.8 |   545.6
torch.Size([64, 8, 8])         |  1817.2 |   530.2 |   520.6 |  1008.7 |   566.3
torch.Size([128, 8, 8])        |  2678.7 |   531.6 |   552.2 |  1006.2 |   555.0
torch.Size([1, 16, 16])        |   738.2 |   546.1 |   536.6 |   775.6 |   540.1
torch.Size([2, 16, 16])        |   782.6 |   543.5 |   539.6 |  1010.9 |   541.1
torch.Size([4, 16, 16])        |   815.2 |   546.1 |   560.9 |  1012.5 |   553.1
torch.Size([8, 16, 16])        |   877.7 |   543.0 |   547.9 |  1012.8 |   551.5
torch.Size([16, 16, 16])       |  1008.7 |   549.2 |   562.7 |  1016.6 |   546.8
torch.Size([32, 16, 16])       |  1291.9 |   540.8 |   560.3 |  1055.8 |   539.3
torch.Size([64, 16, 16])       |  1846.3 |   553.5 |   556.0 |  1010.8 |   551.9
torch.Size([128, 16, 16])      |  2953.8 |   562.7 |   547.5 |  1026.2 |   555.8
torch.Size([1, 32, 32])        |   789.1 |   590.6 |   590.9 |   790.5 |   579.0
torch.Size([2, 32, 32])        |   806.9 |   596.6 |   600.2 |  1085.6 |   573.8
torch.Size([4, 32, 32])        |   852.0 |   597.9 |   588.2 |  1098.9 |   574.7
torch.Size([8, 32, 32])        |   914.2 |   597.8 |   591.4 |  1090.3 |   585.7
torch.Size([16, 32, 32])       |  1063.0 |   604.6 |   597.3 |  1094.0 |   580.5
torch.Size([32, 32, 32])       |  1302.0 |   602.0 |   598.9 |  1090.3 |   583.6
torch.Size([64, 32, 32])       |  1861.7 |   601.1 |   599.8 |  1113.4 |   588.6
torch.Size([128, 32, 32])      |  3251.0 |   619.6 |   595.3 |  1106.8 |   608.9
torch.Size([1, 64, 64])        |   978.6 |   842.7 |   778.6 |  1071.4 |   825.8
torch.Size([2, 64, 64])        |  1072.3 |   845.7 |   785.4 |  1400.6 |   829.0
torch.Size([4, 64, 64])        |  1051.9 |   842.9 |   796.1 |  1352.2 |   788.2
torch.Size([8, 64, 64])        |  1090.3 |   834.1 |   805.2 |  1382.6 |   804.7
torch.Size([16, 64, 64])       |  1206.9 |   835.7 |   802.2 |  1365.6 |   801.2
torch.Size([32, 64, 64])       |  1671.2 |   846.5 |   794.5 |  1345.1 |   814.2
torch.Size([64, 64, 64])       |  2759.3 |   848.5 |   795.4 |  1409.7 |   832.9
torch.Size([128, 64, 64])      |  4928.6 |   877.4 |   848.3 |  1439.0 |   883.9
torch.Size([1, 128, 128])      |  1315.6 |  1158.4 |  1130.0 |  1301.3 |  1177.1
torch.Size([2, 128, 128])      |  1334.7 |  1198.2 |  1186.6 |  1703.9 |  1209.5
torch.Size([4, 128, 128])      |  1374.6 |  1200.7 |  1266.2 |  1640.6 |  1272.3
torch.Size([8, 128, 128])      |  1453.6 |  1215.9 |  1287.3 |  1669.1 |  1288.7
torch.Size([16, 128, 128])     |  1882.1 |  1244.9 |  1337.6 |  1698.8 |  1347.1
torch.Size([32, 128, 128])     |  2789.0 |  1284.5 |  1398.6 |  1747.6 |  1396.3
torch.Size([64, 128, 128])     |  4763.0 |  1425.2 |  1581.7 |  1921.0 |  1584.1
torch.Size([128, 128, 128])    |  8835.9 |  1808.9 |  1968.7 |  2197.6 |  1961.8
torch.Size([1, 512, 512])      |  4369.9 |  4577.6 |  4804.0 |  4331.4 |  4599.0
torch.Size([2, 512, 512])      |  4635.9 |  4850.1 |  5159.1 |  5315.4 |  4845.5
torch.Size([4, 512, 512])      |  5367.5 |  5261.6 |  6134.7 |  5807.8 |  5345.2
torch.Size([8, 512, 512])      |  7025.2 |  6184.5 |  7065.6 |  6711.6 |  6303.9
torch.Size([16, 512, 512])     | 10221.3 |  7849.7 |  8820.1 |  8323.6 |  7992.1
torch.Size([32, 512, 512])     | 16574.8 | 11208.4 | 12284.3 | 11704.7 | 11394.4
torch.Size([64, 512, 512])     | 29500.1 | 18043.1 | 19249.3 | 18744.0 | 18242.1
torch.Size([128, 512, 512])    | 56783.3 | 33903.9 | 34713.5 | 33893.8 | 34041.8
torch.Size([1, 1024, 1024])    | 14864.5 | 15714.6 | 16128.1 | 14726.7 | 14992.6
torch.Size([2, 1024, 1024])    | 17891.0 | 18553.3 | 19111.6 | 19271.5 | 19283.0
torch.Size([4, 1024, 1024])    | 22143.4 | 21909.2 | 23667.1 | 22698.9 | 22713.8
torch.Size([8, 1024, 1024])    | 30621.1 | 28669.9 | 30822.9 | 29725.0 | 29760.8
torch.Size([16, 1024, 1024])   | 47045.9 | 41900.0 | 44353.8 | 43215.6 | 43237.5
torch.Size([32, 1024, 1024])   | 79245.5 | 68316.9 | 70959.0 | 69506.4 | 69876.7
torch.Size([64, 1024, 1024])   | 147973.9 | 121120.6 | 124601.1 | 122084.4 | 122578.7
torch.Size([128, 1024, 1024])  | 295586.2 | 232871.8 | 237421.8 | 233765.3 | 234704.6

Times are in microseconds (us).
```

</details>

Here are the details of how the tests were performed:

* CUSOLVER - Only call `cusolver` for all problem sizes.
* MAGMA - Only call `magma` for all problem sizes (this is the current master branch).
* CUBLAS - Only call `cublas` for all problem sizes.
* cuSOLVER / MAGMA - Use `cusolver` for `batch_size == 1` and `magma` for all other batch sizes.
* TEST ALL - Employ heuristics to switch between cublas/cusolver/magma. This yields the best overall results (this PR); a sketch of this kind of heuristic follows this list.
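To make the dispatch heuristic concrete, here is a minimal Python sketch of the kind of size-based switching described above. The function name and the exact thresholds are illustrative assumptions, not the values used by this PR; the real decision is made in the C++ dispatch code and was tuned against benchmarks like the ones above.

```python
# Minimal sketch of a size-based backend dispatch for lu_solve.
# NOTE: choose_lu_solve_backend and the thresholds below are hypothetical,
# chosen only to illustrate the shape of the heuristic; the actual cutoffs
# live in the PR's C++ dispatch code.
def choose_lu_solve_backend(batch_size: int, n: int) -> str:
    """Pick a GPU backend from the batch size and matrix dimension n."""
    if batch_size == 1:
        # Unbatched solves: cuSOLVER is competitive across matrix sizes
        # (see the cuSOLVER column for torch.Size([1, n, n]) above).
        return "cusolver"
    if n <= 64:
        # Many small matrices: batched cuBLAS avoids the per-matrix
        # overhead that grows with batch size under cuSOLVER.
        return "cublas"
    # Large batched matrices: fall back to MAGMA.
    return "magma"
```

The key property is that the decision depends only on quantities known before launching any kernel (batch size and matrix dimension), so the dispatch itself adds no measurable overhead.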
Script for reproducing the results:

<details>

```python
import itertools
import pickle
import sys

import torch
from torch.utils.benchmark import Timer
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

shapes = [1, 2, 8, 16, 32, 64, 128, 512, 1024]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,)]
results = []
num_threads = 1
dtype = torch.float64
repeats = 2


def lu_factorize_solve(mat, b):
    # Factorize, then solve; the solve step is what this PR accelerates.
    lu_data = torch.lu(mat)
    x = torch.lu_solve(b, *lu_data)


for shape, batch in itertools.product(shapes, batches):
    mat = torch.randn(*batch, shape, shape, dtype=dtype, device='cuda')
    b = torch.randn(*batch, shape, 1, dtype=dtype, device='cuda')
    tasks = [("lu_factorize_solve(mat, b)", "lu_solve CUSOLVER")]
    print("shape: ", shape, " batch: ", batch)
    timers = [
        Timer(
            stmt=stmt,
            num_threads=num_threads,
            label=f"LU solve CUDA {dtype}",
            sub_label=f"{mat.shape}",
            description=label,
            globals=globals(),
        )
        for stmt, label in tasks
    ]
    for i, timer in enumerate(timers * repeats):
        results.append(pickle.dumps(timer.blocked_autorange()))
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()

with open("cusolver_lu_solve.pickle", "wb") as f:
    pickle.dump(results, f)
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59148

Reviewed By: H-Huang

Differential Revision: D29160609

Pulled By: mruberry

fbshipit-source-id: 7280f25db1e66aa650ea15608a6dc5d688fb4db2
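For reference, the pickled measurements written by the script above can be rendered back into a comparison table like the one shown earlier using the `Compare` helper from `torch.utils.benchmark`. This is a minimal sketch assuming the `cusolver_lu_solve.pickle` file produced by that script:

```python
import pickle

from torch.utils.benchmark import Compare

# The script stores a list of pickled Measurement objects; load the list,
# unpickle each measurement, and print them as one comparison table.
with open("cusolver_lu_solve.pickle", "rb") as f:
    results = pickle.load(f)

measurements = [pickle.loads(r) for r in results]
Compare(measurements).print()
```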