Performance and memory improvements to batched torch.linalg.solve (#69752)
Summary:
Previously, for a single input matrix A and a batched matrix B, A was expanded and cloned before computing the LU decomposition and solving the linear system.
With this PR, the LU decomposition is computed once for the single matrix and is expanded and cloned only if required by the backend library call that solves the linear system.
Here's a basic comparison:
```python
# BEFORE THE PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
...: torch.linalg.solve(a, b)
...:
...:
329 ms ± 17.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# WITH THIS PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
...: torch.linalg.solve(a, b)
...:
...:
21.4 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
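The benchmark relies on `torch.linalg.solve` broadcasting a single coefficient matrix against a batched right-hand side. A small sketch (shapes chosen arbitrarily, not from the PR) of the equivalence the optimization exploits — solving with the single A is mathematically the same as solving with A expanded to the batch shape, but avoids the redundant per-batch LU factorizations:

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 4)      # single coefficient matrix
b = torch.randn(8, 4, 2)   # batch of 8 right-hand sides

# Broadcasting: the single A is applied to every batch element of B.
x = torch.linalg.solve(a, b)

# Equivalent result when A is explicitly expanded to the batch shape.
x_expanded = torch.linalg.solve(a.expand(8, 4, 4), b)

assert torch.allclose(x, x_expanded)
# Check the solution satisfies A @ X = B for every batch element.
assert torch.allclose(a @ x, b, atol=1e-5)
```

With the PR, the first call factorizes A once and reuses the factors across the batch, which is where the roughly 15x speedup in the timings above comes from.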
cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69752
Reviewed By: albanD
Differential Revision: D33028236
Pulled By: mruberry
fbshipit-source-id: 7a0dd443cd0ece81777c68b29438750f6524ac24