e1338016 - cuSOLVER path for LU factorization in CUDA. (#56887)

Commit · 3 years ago
Summary: This PR adds a cuSOLVER path for `torch.lu`.

Performance comparison results: https://github.com/pytorch/pytorch/issues/53879#issuecomment-830635381
Code for reproducing the performance results: https://github.com/pytorch/pytorch/pull/56887#issuecomment-843212868

The following heuristic is used for choosing cuSOLVER over MAGMA:

* If batch size == 1, OR (batch size <= 8 AND shape <= 16), choose cuSOLVER.
* For all other cases, use MAGMA.

See also https://github.com/pytorch/pytorch/issues/47953.

The following are the performance results comparing the master branch and the current changes:

<details>

```
[-------------------------- LU factorization (ATen) torch.float64 ---------------------------]
                                      |  lu_factorize CURRENT  |  lu_factorize MASTER
1 threads: -----------------------------------------------------------------------------------
      torch.Size([1, 1, 1])           |          363.9         |         284.1
      torch.Size([2, 1, 1])           |          354.8         |         271.8
      torch.Size([4, 1, 1])           |          393.7         |         278.0
      torch.Size([8, 1, 1])           |          459.3         |         279.1
      torch.Size([16, 1, 1])          |          524.2         |         288.9
      torch.Size([32, 1, 1])          |          525.1         |         281.2
      torch.Size([64, 1, 1])          |          524.5         |         281.7
      torch.Size([128, 1, 1])         |          522.8         |         285.2
      torch.Size([1, 2, 2])           |          360.4         |         277.7
      torch.Size([2, 2, 2])           |          372.9         |         279.2
      torch.Size([4, 2, 2])           |          419.4         |         278.3
      torch.Size([8, 2, 2])           |          475.7         |         279.2
      torch.Size([16, 2, 2])          |          530.0         |         299.5
      torch.Size([32, 2, 2])          |          530.0         |         294.5
      torch.Size([64, 2, 2])          |          531.0         |         291.5
      torch.Size([128, 2, 2])         |          544.4         |         292.3
      torch.Size([1, 8, 8])           |          372.6         |         292.8
      torch.Size([2, 8, 8])           |          380.9         |         296.2
      torch.Size([4, 8, 8])           |          420.0         |         293.4
      torch.Size([8, 8, 8])           |          490.6         |         294.6
      torch.Size([16, 8, 8])          |          535.6         |         296.5
      torch.Size([32, 8, 8])          |          534.7         |         302.1
      torch.Size([64, 8, 8])          |          539.1         |         305.5
      torch.Size([128, 8, 8])         |          540.7         |         296.5
      torch.Size([1, 16, 16])         |          345.0         |         303.2
      torch.Size([2, 16, 16])         |          405.0         |         306.3
      torch.Size([4, 16, 16])         |          482.8         |         305.6
      torch.Size([8, 16, 16])         |          596.3         |         305.9
      torch.Size([16, 16, 16])        |          539.6         |         304.4
      torch.Size([32, 16, 16])        |          542.2         |         305.8
      torch.Size([64, 16, 16])        |          556.1         |         311.0
      torch.Size([128, 16, 16])       |          545.1         |         308.1
      torch.Size([1, 32, 32])         |          432.7         |         342.4
      torch.Size([2, 32, 32])         |          582.6         |         341.8
      torch.Size([4, 32, 32])         |          580.4         |         344.4
      torch.Size([8, 32, 32])         |          586.5         |         343.8
      torch.Size([16, 32, 32])        |          582.9         |         346.0
      torch.Size([32, 32, 32])        |          574.4         |         343.7
      torch.Size([64, 32, 32])        |          562.8         |         350.8
      torch.Size([128, 32, 32])       |          568.3         |         349.8
      torch.Size([1, 64, 64])         |          537.1         |         518.4
      torch.Size([2, 64, 64])         |          766.5         |         539.1
      torch.Size([4, 64, 64])         |          771.6         |         551.9
      torch.Size([8, 64, 64])         |          783.4         |         556.0
      torch.Size([16, 64, 64])        |          798.8         |         555.3
      torch.Size([32, 64, 64])        |          795.6         |         548.6
      torch.Size([64, 64, 64])        |          804.2         |         580.4
      torch.Size([128, 64, 64])       |          837.6         |         616.9
      torch.Size([1, 128, 128])       |          844.7         |         848.9
      torch.Size([2, 128, 128])       |         1096.7         |         873.3
      torch.Size([4, 128, 128])       |         1117.9         |         884.8
      torch.Size([8, 128, 128])       |         1138.1         |         903.6
      torch.Size([16, 128, 128])      |         1169.1         |         943.9
      torch.Size([32, 128, 128])      |         1204.8         |         981.4
      torch.Size([64, 128, 128])      |         1336.6         |        1105.8
      torch.Size([128, 128, 128])     |         1639.4         |        1473.3
      torch.Size([1, 512, 512])       |         3714.3         |        3928.6
      torch.Size([2, 512, 512])       |         4388.3         |        4179.7
      torch.Size([4, 512, 512])       |         4765.4         |        4536.9
      torch.Size([8, 512, 512])       |         5615.2         |        5441.1
      torch.Size([16, 512, 512])      |         7203.6         |        7130.2
      torch.Size([32, 512, 512])      |        10580.5         |       10503.9
      torch.Size([64, 512, 512])      |        17374.8         |       17349.6
      torch.Size([128, 512, 512])     |        32542.3         |       32548.8
      torch.Size([1, 1024, 1024])     |        10041.5         |       14292.3
      torch.Size([2, 1024, 1024])     |        17126.6         |       16971.0
      torch.Size([4, 1024, 1024])     |        20591.0         |       20490.8
      torch.Size([8, 1024, 1024])     |        27682.8         |       27560.7
      torch.Size([16, 1024, 1024])    |        41035.2         |       41035.8
      torch.Size([32, 1024, 1024])    |        67091.8         |       67345.9
      torch.Size([64, 1024, 1024])    |       119612.3         |      119782.3
      torch.Size([128, 1024, 1024])   |       230095.5         |      230766.2

Times are in microseconds (us).
```

</details>

The main reason a performance regression can be seen in some of these cases is related to https://github.com/pytorch/pytorch/issues/55122, and there seems to be no easy way to fix it (at least in this PR).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56887
Reviewed By: ngimel
Differential Revision: D29482342
Pulled By: mruberry
fbshipit-source-id: 4fdedf21b0d5597b289e168dff61d5f5d7727fb1
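The cuSOLVER/MAGMA dispatch heuristic from the commit message can be sketched as a small predicate. This is a minimal Python illustration, not the actual ATen C++ dispatch code; the function name `prefer_cusolver_lu` is hypothetical:

```python
def prefer_cusolver_lu(batch_size: int, n: int) -> bool:
    """Return True when the heuristic picks cuSOLVER over MAGMA.

    Per the commit message: choose cuSOLVER for a single matrix,
    or for a small batch (<= 8) of tiny matrices (shape <= 16);
    use MAGMA for all other cases.
    """
    return batch_size == 1 or (batch_size <= 8 and n <= 16)

# A few cases matching the benchmark shapes above:
print(prefer_cusolver_lu(1, 1024))  # single matrix -> cuSOLVER (True)
print(prefer_cusolver_lu(8, 16))    # small batch, tiny matrices -> cuSOLVER (True)
print(prefer_cusolver_lu(16, 16))   # batch too large -> MAGMA (False)
print(prefer_cusolver_lu(2, 32))    # matrices too large for the batched path -> MAGMA (False)
```

The benchmark table is consistent with this split: cuSOLVER wins decisively only for the unbatched large sizes (e.g. `[1, 1024, 1024]`), while MAGMA remains competitive for large batches.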
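For reference, the operation being dispatched here is LU factorization with partial pivoting, which `torch.lu` computes on the GPU via MAGMA or (with this PR) cuSOLVER. A pure-NumPy sketch of the math is below; the helper `lu_factor` is a hypothetical illustration and unrelated to the actual cuSOLVER/MAGMA kernels, and its pivot encoding differs from the one `torch.lu` returns:

```python
import numpy as np

def lu_factor(A):
    """LU factorization with partial pivoting (illustrative only).

    Returns the combined LU matrix (unit-lower L stored strictly below
    the diagonal, U on and above it) and the row-permutation indices.
    """
    LU = np.array(A, dtype=float, copy=True)
    n = LU.shape[0]
    piv = np.arange(n)
    for k in range(n - 1):
        # Partial pivoting: pick the largest-magnitude entry in column k.
        p = k + int(np.argmax(np.abs(LU[k:, k])))
        if p != k:
            LU[[k, p]] = LU[[p, k]]
            piv[[k, p]] = piv[[p, k]]
        # Store the elimination multipliers in the strictly lower part...
        LU[k + 1:, k] /= LU[k, k]
        # ...and update the trailing submatrix.
        LU[k + 1:, k + 1:] -= np.outer(LU[k + 1:, k], LU[k, k + 1:])
    return LU, piv

# Sanity check: with P the row permutation, P @ A == L @ U.
A = np.random.default_rng(0).standard_normal((4, 4))
LU, piv = lu_factor(A)
L = np.tril(LU, -1) + np.eye(4)
U = np.triu(LU)
print(np.allclose(A[piv], L @ U))  # True
```

The batched GPU kernels factorize many such matrices at once, which is why both the batch size and the per-matrix shape appear in the heuristic.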