pytorch
064206df - Performance and memory improvements to batched torch.linalg.solve (2nd attempt) (#71756)

Commit
3 years ago
Performance and memory improvements to batched torch.linalg.solve (2nd attempt) (#71756) Summary: Previous PR with the same content: https://github.com/pytorch/pytorch/pull/69752. Opening a new PR by request: https://github.com/pytorch/pytorch/pull/69752#issuecomment-1020829812. ------ Previously for single input matrix A and batched matrix B, matrix A was expanded and cloned before computing the LU decomposition and solving the linear system. With this PR the LU decomposition is computed once for a single matrix and then expanded&cloned if required by a backend library call for the linear system solving. Here's a basic comparison: ```python # BEFORE THE PR In [1]: import torch In [2]: a = torch.randn(256, 256) In [3]: b = torch.randn(1024, 256, 2) In [4]: %%timeit ...: torch.linalg.solve(a, b) ...: ...: 329 ms ± 17.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) # WITH THIS PR In [1]: import torch In [2]: a = torch.randn(256, 256) In [3]: b = torch.randn(1024, 256, 2) In [4]: %%timeit ...: torch.linalg.solve(a, b) ...: ...: 21.4 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) ``` Fixes https://github.com/pytorch/pytorch/issues/71406, fixes https://github.com/pytorch/pytorch/issues/71610 Pull Request resolved: https://github.com/pytorch/pytorch/pull/71756 Reviewed By: ngimel Differential Revision: D33771981 Pulled By: mruberry fbshipit-source-id: 0917ee36a3eb622ff75d54787b1bffe26b41cb4a (cherry picked from commit 9c30a05aaa972bc02dfc94c3d2463f0c5ee0c58c)
Author
Committer
Parents
Loading