Performance and memory improvements to batched torch.linalg.solve (2nd attempt) (#71756)
Summary:
Previous PR with the same content: https://github.com/pytorch/pytorch/pull/69752. Opening a new PR by request: https://github.com/pytorch/pytorch/pull/69752#issuecomment-1020829812.
------
Previously, when A was a single matrix and B was batched, A was expanded and cloned to match the batch shape of B before the LU decomposition was computed and the linear system was solved, so the same factorization was repeated for every batch element.
With this PR, the LU decomposition is computed once for the single matrix; the result is expanded and cloned only if the backend library call that solves the linear system requires it.
Here's a basic comparison:
```python
# BEFORE THE PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
...: torch.linalg.solve(a, b)
...:
...:
329 ms ± 17.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# WITH THIS PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
...: torch.linalg.solve(a, b)
...:
...:
21.4 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
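The following sketch shows, from the user side, the two call patterns being discussed. It is only an illustration under stated assumptions: the small shapes, the float64 dtype, the diagonal shift that keeps `a` well conditioned, and the names `a_batched`, `x_broadcast` and `x_expanded` are all made up for this example, and the explicit `expand(...).clone()` merely mimics what the old code path did internally rather than reproducing the actual C++ implementation. Both calls should return the same solution; only the amount of work and memory differs.

```python
import torch

torch.manual_seed(0)

# Single (unbatched) coefficient matrix; the diagonal shift is an
# illustrative assumption that keeps the matrix well conditioned.
a = torch.randn(4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)

# Batched right-hand side: 8 systems, each with 2 columns.
b = torch.randn(8, 4, 2, dtype=torch.float64)

# Broadcasting call: `a` is passed as a single matrix.
x_broadcast = torch.linalg.solve(a, b)

# Hand-written equivalent of "expand and clone A first":
# materialize 8 identical copies of `a`, then solve the batch.
a_batched = a.expand(8, -1, -1).clone()
x_expanded = torch.linalg.solve(a_batched, b)

# Both approaches should agree up to floating-point rounding.
print(torch.allclose(x_broadcast, x_expanded))
```

The second call allocates a full batch of copies of `a` up front, which is the memory overhead this PR avoids for the single-matrix case.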
Fixes https://github.com/pytorch/pytorch/issues/71406, fixes https://github.com/pytorch/pytorch/issues/71610
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71756
Reviewed By: ngimel
Differential Revision: D33771981
Pulled By: mruberry
fbshipit-source-id: 0917ee36a3eb622ff75d54787b1bffe26b41cb4a
(cherry picked from commit 9c30a05aaa972bc02dfc94c3d2463f0c5ee0c58c)