`torch.linalg.householder_product` faster backward (#63880)
Summary:
This PR implements a much more efficient backward algorithm for `torch.linalg.householder_product`. The new algorithm achieves massive speed-ups, especially for batched and/or larger double-precision inputs.
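For context, `torch.linalg.householder_product` reconstructs the matrix Q from the compact `(a, tau)` Householder representation (as returned, for example, by `torch.geqrf`); the backward pass through that reconstruction is what this PR speeds up. A minimal illustrative sketch, not part of this PR (the `geqrf` source and the 100x100 size are arbitrary choices for the example):
```python
import torch

# Compact Householder representation of Q, e.g. from a QR factorization.
A = torch.randn(100, 100, dtype=torch.double)
a, tau = torch.geqrf(A)

# Differentiate through the reconstruction of Q.
a = a.detach().requires_grad_(True)
tau = tau.detach().requires_grad_(True)
Q = torch.linalg.householder_product(a, tau)  # forward
Q.sum().backward()                            # backward: the path made faster here
```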
Here are some benchmarks:
<details>
<summary>Testing script</summary>
```python
from IPython import get_ipython
import torch
import itertools

torch.manual_seed(13)
#torch.set_num_threads(1)

ipython = get_ipython()

cpu = torch.device('cpu')
cuda = torch.device('cuda')

def generate_input(shape, dtype=torch.double, device=cpu):
    eigvals = torch.rand(*shape[:-1], dtype=dtype, device=device)
    eigvecs = torch.rand(*shape, dtype=dtype, device=device)
    input = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.inverse()
    input.requires_grad_(True)
    tau = torch.rand(*shape[:-1], dtype=dtype, device=device)
    tau.requires_grad_(True)
    return input, tau

def run_test(shape, device, dtype):
    print(f"shape: {shape}, device: {device}, dtype: {dtype}")
    a, tau = generate_input(shape, dtype=dtype, device=device)
    prod = torch.linalg.householder_product(a, tau)
    ones_prod = torch.ones_like(prod)
    command = "torch.autograd.backward((prod,), (ones_prod), retain_graph=True)"
    if device == cuda:
        command = command + "; torch.cuda.synchronize()"
    ipython.magic(f"timeit {command}")
    print()

dtypes = [torch.float, torch.double]
devices = [cpu, cuda]
#devices = [cuda]
sizes = [
    (10, 10),
    (1000, 10, 10),
    (100, 100),
    (1000, 100, 100),
    (1000, 1000),
    (10, 1000, 1000),
]

for device, dtype, size in itertools.product(devices, dtypes, sizes):
    run_test(size, device, dtype)
```
</details>
<details>
<summary>This PR, cuda float32</summary>
```
shape: (10, 10), device: cuda, dtype: torch.float32
1.33 ms ± 1.82 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (1000, 10, 10), device: cuda, dtype: torch.float32
1.52 ms ± 40.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (100, 100), device: cuda, dtype: torch.float32
10.8 ms ± 9.62 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (1000, 100, 100), device: cuda, dtype: torch.float32
127 ms ± 8.45 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
shape: (1000, 1000), device: cuda, dtype: torch.float32
151 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
shape: (10, 1000, 1000), device: cuda, dtype: torch.float32
981 ms ± 91.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
<details>
<summary>Master, cuda float32</summary>
```
shape: (10, 10), device: cuda, dtype: torch.float32
1.64 ms ± 6.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (1000, 10, 10), device: cuda, dtype: torch.float32
298 ms ± 463 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (100, 100), device: cuda, dtype: torch.float32
15.4 ms ± 41.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (1000, 100, 100), device: cuda, dtype: torch.float32
5.36 s ± 711 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 1000), device: cuda, dtype: torch.float32
1.64 s ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (10, 1000, 1000), device: cuda, dtype: torch.float32
15.7 s ± 43.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
<details>
<summary>This PR, cuda float64</summary>
```
shape: (10, 10), device: cuda, dtype: torch.float64
1.14 ms ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (1000, 10, 10), device: cuda, dtype: torch.float64
2.22 ms ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (100, 100), device: cuda, dtype: torch.float64
10.6 ms ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (1000, 100, 100), device: cuda, dtype: torch.float64
287 ms ± 84.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 1000), device: cuda, dtype: torch.float64
236 ms ± 41.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (10, 1000, 1000), device: cuda, dtype: torch.float64
1.88 s ± 88.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
<details>
<summary>Master, cuda float64</summary>
```
shape: (10, 10), device: cuda, dtype: torch.float64
1.58 ms ± 8.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (1000, 10, 10), device: cuda, dtype: torch.float64
308 ms ± 213 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (100, 100), device: cuda, dtype: torch.float64
79 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
shape: (1000, 100, 100), device: cuda, dtype: torch.float64
54.2 s ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 1000), device: cuda, dtype: torch.float64
31.5 s ± 698 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (10, 1000, 1000), device: cuda, dtype: torch.float64
4min 45s ± 2.48 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
<details>
<summary>This PR, cpu float32</summary>
```
shape: (10, 10), device: cpu, dtype: torch.float32
476 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 10, 10), device: cpu, dtype: torch.float32
5.1 ms ± 100 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (100, 100), device: cpu, dtype: torch.float32
4.38 ms ± 4.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (1000, 100, 100), device: cpu, dtype: torch.float32
1.55 s ± 6.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 1000), device: cpu, dtype: torch.float32
745 ms ± 407 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (10, 1000, 1000), device: cpu, dtype: torch.float32
5.44 s ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
<details>
<summary>Master, cpu float32</summary>
```
shape: (10, 10), device: cpu, dtype: torch.float32
387 µs ± 645 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (1000, 10, 10), device: cpu, dtype: torch.float32
12.3 ms ± 23.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (100, 100), device: cpu, dtype: torch.float32
39.4 ms ± 80.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
shape: (1000, 100, 100), device: cpu, dtype: torch.float32
29.1 s ± 44.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 1000), device: cpu, dtype: torch.float32
9.42 s ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (10, 1000, 1000), device: cpu, dtype: torch.float32
1min 50s ± 282 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
<details>
<summary>This PR, cpu float64</summary>
```
shape: (10, 10), device: cpu, dtype: torch.float64
381 µs ± 761 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (1000, 10, 10), device: cpu, dtype: torch.float64
6.19 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (100, 100), device: cpu, dtype: torch.float64
4.6 ms ± 3.26 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (1000, 100, 100), device: cpu, dtype: torch.float64
2.59 s ± 8.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 1000), device: cpu, dtype: torch.float64
1.07 s ± 5.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (10, 1000, 1000), device: cpu, dtype: torch.float64
14.4 s ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
<details>
<summary>Master, cpu float64</summary>
```
shape: (10, 10), device: cpu, dtype: torch.float64
395 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
shape: (1000, 10, 10), device: cpu, dtype: torch.float64
14.6 ms ± 9.76 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
shape: (100, 100), device: cpu, dtype: torch.float64
45.5 ms ± 154 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
shape: (1000, 100, 100), device: cpu, dtype: torch.float64
33.1 s ± 69.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (1000, 1000), device: cpu, dtype: torch.float64
19.3 s ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
shape: (10, 1000, 1000), device: cpu, dtype: torch.float64
3min 30s ± 1.29 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63880
Reviewed By: soulitzer
Differential Revision: D30639435
Pulled By: anjali411
fbshipit-source-id: 127789943ae56e2f1dd03e0fe76ef7b6db86bcf0