cuSOLVER path for LU factorization in CUDA. (#56887)
Summary:
This PR adds a cuSOLVER path for `torch.lu`.
Performance comparison results: https://github.com/pytorch/pytorch/issues/53879#issuecomment-830635381
Code for reproducing performance results: https://github.com/pytorch/pytorch/pull/56887#issuecomment-843212868
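For context, here is a minimal usage sketch of the operation this PR accelerates (this is not the benchmark script linked above, just an illustrative example of batched LU factorization on the GPU):
```python
import torch

# A small batch of square matrices on the GPU; this is the code path
# affected by the cuSOLVER/MAGMA dispatch added in this PR.
A = torch.randn(8, 16, 16, dtype=torch.float64, device="cuda")

# torch.lu returns the packed LU factors and the pivot indices.
LU, pivots = torch.lu(A)

# The factors can be unpacked into explicit P, L, U matrices for verification.
P, L, U = torch.lu_unpack(LU, pivots)
torch.testing.assert_allclose(P @ L @ U, A)
```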
The following heuristic is used to choose between cuSOLVER and MAGMA:
* If batch size == 1, or batch size <= 8 and matrix size <= 16, use cuSOLVER.
* For all other cases, use MAGMA (a sketch of this dispatch logic is shown below).
See also https://github.com/pytorch/pytorch/issues/47953.
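For illustration, a Python-level sketch of the dispatch heuristic above (the actual check lives in ATen's C++ CUDA backend; the helper name here is hypothetical):
```python
def use_cusolver_for_lu(batch_size: int, n: int) -> bool:
    """Hypothetical helper mirroring the heuristic above: prefer cuSOLVER
    for a single matrix, or for small batches of small matrices; fall back
    to MAGMA everywhere else."""
    return batch_size == 1 or (batch_size <= 8 and n <= 16)


# Example: a batch of 4 matrices of size 16x16 takes the cuSOLVER path,
# while a batch of 32 matrices of size 64x64 uses MAGMA.
assert use_cusolver_for_lu(4, 16)
assert not use_cusolver_for_lu(32, 64)
```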
Below are the performance results comparing the master branch against this PR's changes:
<details>
```
[-------------------------- LU factorization (ATen) torch.float64 ---------------------------]
| lu_factorize CURRENT | lu_factorize MASTER
1 threads: -----------------------------------------------------------------------------------
torch.Size([1, 1, 1]) | 363.9 | 284.1
torch.Size([2, 1, 1]) | 354.8 | 271.8
torch.Size([4, 1, 1]) | 393.7 | 278.0
torch.Size([8, 1, 1]) | 459.3 | 279.1
torch.Size([16, 1, 1]) | 524.2 | 288.9
torch.Size([32, 1, 1]) | 525.1 | 281.2
torch.Size([64, 1, 1]) | 524.5 | 281.7
torch.Size([128, 1, 1]) | 522.8 | 285.2
torch.Size([1, 2, 2]) | 360.4 | 277.7
torch.Size([2, 2, 2]) | 372.9 | 279.2
torch.Size([4, 2, 2]) | 419.4 | 278.3
torch.Size([8, 2, 2]) | 475.7 | 279.2
torch.Size([16, 2, 2]) | 530.0 | 299.5
torch.Size([32, 2, 2]) | 530.0 | 294.5
torch.Size([64, 2, 2]) | 531.0 | 291.5
torch.Size([128, 2, 2]) | 544.4 | 292.3
torch.Size([1, 8, 8]) | 372.6 | 292.8
torch.Size([2, 8, 8]) | 380.9 | 296.2
torch.Size([4, 8, 8]) | 420.0 | 293.4
torch.Size([8, 8, 8]) | 490.6 | 294.6
torch.Size([16, 8, 8]) | 535.6 | 296.5
torch.Size([32, 8, 8]) | 534.7 | 302.1
torch.Size([64, 8, 8]) | 539.1 | 305.5
torch.Size([128, 8, 8]) | 540.7 | 296.5
torch.Size([1, 16, 16]) | 345.0 | 303.2
torch.Size([2, 16, 16]) | 405.0 | 306.3
torch.Size([4, 16, 16]) | 482.8 | 305.6
torch.Size([8, 16, 16]) | 596.3 | 305.9
torch.Size([16, 16, 16]) | 539.6 | 304.4
torch.Size([32, 16, 16]) | 542.2 | 305.8
torch.Size([64, 16, 16]) | 556.1 | 311.0
torch.Size([128, 16, 16]) | 545.1 | 308.1
torch.Size([1, 32, 32]) | 432.7 | 342.4
torch.Size([2, 32, 32]) | 582.6 | 341.8
torch.Size([4, 32, 32]) | 580.4 | 344.4
torch.Size([8, 32, 32]) | 586.5 | 343.8
torch.Size([16, 32, 32]) | 582.9 | 346.0
torch.Size([32, 32, 32]) | 574.4 | 343.7
torch.Size([64, 32, 32]) | 562.8 | 350.8
torch.Size([128, 32, 32]) | 568.3 | 349.8
torch.Size([1, 64, 64]) | 537.1 | 518.4
torch.Size([2, 64, 64]) | 766.5 | 539.1
torch.Size([4, 64, 64]) | 771.6 | 551.9
torch.Size([8, 64, 64]) | 783.4 | 556.0
torch.Size([16, 64, 64]) | 798.8 | 555.3
torch.Size([32, 64, 64]) | 795.6 | 548.6
torch.Size([64, 64, 64]) | 804.2 | 580.4
torch.Size([128, 64, 64]) | 837.6 | 616.9
torch.Size([1, 128, 128]) | 844.7 | 848.9
torch.Size([2, 128, 128]) | 1096.7 | 873.3
torch.Size([4, 128, 128]) | 1117.9 | 884.8
torch.Size([8, 128, 128]) | 1138.1 | 903.6
torch.Size([16, 128, 128]) | 1169.1 | 943.9
torch.Size([32, 128, 128]) | 1204.8 | 981.4
torch.Size([64, 128, 128]) | 1336.6 | 1105.8
torch.Size([128, 128, 128]) | 1639.4 | 1473.3
torch.Size([1, 512, 512]) | 3714.3 | 3928.6
torch.Size([2, 512, 512]) | 4388.3 | 4179.7
torch.Size([4, 512, 512]) | 4765.4 | 4536.9
torch.Size([8, 512, 512]) | 5615.2 | 5441.1
torch.Size([16, 512, 512]) | 7203.6 | 7130.2
torch.Size([32, 512, 512]) | 10580.5 | 10503.9
torch.Size([64, 512, 512]) | 17374.8 | 17349.6
torch.Size([128, 512, 512]) | 32542.3 | 32548.8
torch.Size([1, 1024, 1024]) | 10041.5 | 14292.3
torch.Size([2, 1024, 1024]) | 17126.6 | 16971.0
torch.Size([4, 1024, 1024]) | 20591.0 | 20490.8
torch.Size([8, 1024, 1024]) | 27682.8 | 27560.7
torch.Size([16, 1024, 1024]) | 41035.2 | 41035.8
torch.Size([32, 1024, 1024]) | 67091.8 | 67345.9
torch.Size([64, 1024, 1024]) | 119612.3 | 119782.3
torch.Size([128, 1024, 1024]) | 230095.5 | 230766.2
Times are in microseconds (us).
```
</details>
The main reason a performance regression can be seen in some cases is related to https://github.com/pytorch/pytorch/issues/55122, and there seems to be no easy way to fix it (at least not in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56887
Reviewed By: ngimel
Differential Revision: D29482342
Pulled By: mruberry
fbshipit-source-id: 4fdedf21b0d5597b289e168dff61d5f5d7727fb1