pytorch
06c37876 - `torch.linalg.householder_product` faster backward (#63880)

Commit
3 years ago
`torch.linalg.householder_product` faster backward (#63880) Summary: This PR implements a much more efficient algorithm. This algorithm allows to achieve MASSIVE speed-ups, especially for batched and/or larger double-precision inputs. Here are some benchmarks: <details> <summary>Testing script</summary> ```python from IPython import get_ipython import torch import itertools torch.manual_seed(13) #torch.set_num_threads(1) ipython = get_ipython() cpu = torch.device('cpu') cuda = torch.device('cuda') def generate_input(shape, dtype=torch.double, device=cpu): eigvals = torch.rand(*shape[:-1], dtype=dtype, device=device) eigvecs = torch.rand(*shape, dtype=dtype, device=device) input = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.inverse() input.requires_grad_(True) tau = torch.rand(*shape[:-1], dtype=dtype, device=device) tau.requires_grad_(True) return input, tau def run_test(shape, device, dtype): print(f"shape: {shape}, device: {device}, dtype: {dtype}") a, tau = generate_input(shape, dtype=dtype, device=device) prod = torch.linalg.householder_product(a, tau) ones_prod = torch.ones_like(prod) command = "torch.autograd.backward((prod,), (ones_prod), retain_graph=True)" if device == cuda: command = command + "; torch.cuda.synchronize()" ipython.magic(f"timeit {command}") print() dtypes = [torch.float, torch.double] devices = [cpu, cuda] #devices = [cuda] sizes = [ (10, 10), (1000, 10, 10), (100, 100), (1000, 100, 100), (1000, 1000), (10, 1000, 1000), ] for device, dtype, size in itertools.product(devices, dtypes, sizes): run_test(size, device, dtype) ``` </details> <details> <summary>This PR, cuda float32</summary> ``` shape: (10, 10), device: cuda, dtype: torch.float32 1.33 ms ± 1.82 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (1000, 10, 10), device: cuda, dtype: torch.float32 1.52 ms ± 40.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (100, 100), device: cuda, dtype: torch.float32 10.8 ms ± 9.62 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (1000, 100, 100), device: cuda, dtype: torch.float32 127 ms ± 8.45 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) shape: (1000, 1000), device: cuda, dtype: torch.float32 151 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) shape: (10, 1000, 1000), device: cuda, dtype: torch.float32 981 ms ± 91.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> <details> <summary>Master, cuda float32</summary> ``` shape: (10, 10), device: cuda, dtype: torch.float32 1.64 ms ± 6.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (1000, 10, 10), device: cuda, dtype: torch.float32 298 ms ± 463 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (100, 100), device: cuda, dtype: torch.float32 15.4 ms ± 41.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (1000, 100, 100), device: cuda, dtype: torch.float32 5.36 s ± 711 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 1000), device: cuda, dtype: torch.float32 1.64 s ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (10, 1000, 1000), device: cuda, dtype: torch.float32 15.7 s ± 43.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> <details> <summary>This PR, cuda float64</summary> ``` shape: (10, 10), device: cuda, dtype: torch.float64 1.14 ms ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (1000, 10, 10), device: cuda, dtype: torch.float64 2.22 ms ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (100, 100), device: cuda, dtype: torch.float64 10.6 ms ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (1000, 100, 100), device: cuda, dtype: torch.float64 287 ms ± 84.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 1000), device: cuda, dtype: torch.float64 236 ms ± 41.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (10, 1000, 1000), device: cuda, dtype: torch.float64 1.88 s ± 88.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> <details> <summary>Master, cuda float64</summary> ``` shape: (10, 10), device: cuda, dtype: torch.float64 1.58 ms ± 8.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (1000, 10, 10), device: cuda, dtype: torch.float64 308 ms ± 213 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (100, 100), device: cuda, dtype: torch.float64 79 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) shape: (1000, 100, 100), device: cuda, dtype: torch.float64 54.2 s ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 1000), device: cuda, dtype: torch.float64 31.5 s ± 698 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (10, 1000, 1000), device: cuda, dtype: torch.float64 4min 45s ± 2.48 s per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> <details> <summary>This PR, cpu float32</summary> ``` shape: (10, 10), device: cpu, dtype: torch.float32 476 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 10, 10), device: cpu, dtype: torch.float32 5.1 ms ± 100 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (100, 100), device: cpu, dtype: torch.float32 4.38 ms ± 4.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (1000, 100, 100), device: cpu, dtype: torch.float32 1.55 s ± 6.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 1000), device: cpu, dtype: torch.float32 745 ms ± 407 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (10, 1000, 1000), device: cpu, dtype: torch.float32 5.44 s ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> <details> <summary>Master, cpu float32</summary> ``` shape: (10, 10), device: cpu, dtype: torch.float32 387 µs ± 645 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (1000, 10, 10), device: cpu, dtype: torch.float32 12.3 ms ± 23.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (100, 100), device: cpu, dtype: torch.float32 39.4 ms ± 80.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) shape: (1000, 100, 100), device: cpu, dtype: torch.float32 29.1 s ± 44.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 1000), device: cpu, dtype: torch.float32 9.42 s ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (10, 1000, 1000), device: cpu, dtype: torch.float32 1min 50s ± 282 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> <details> <summary>This PR, cpu float64</summary> ``` shape: (10, 10), device: cpu, dtype: torch.float64 381 µs ± 761 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (1000, 10, 10), device: cpu, dtype: torch.float64 6.19 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (100, 100), device: cpu, dtype: torch.float64 4.6 ms ± 3.26 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (1000, 100, 100), device: cpu, dtype: torch.float64 2.59 s ± 8.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 1000), device: cpu, dtype: torch.float64 1.07 s ± 5.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (10, 1000, 1000), device: cpu, dtype: torch.float64 14.4 s ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> <details> <summary>Master, cpu float64</summary> ``` shape: (10, 10), device: cpu, dtype: torch.float64 395 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) shape: (1000, 10, 10), device: cpu, dtype: torch.float64 14.6 ms ± 9.76 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) shape: (100, 100), device: cpu, dtype: torch.float64 45.5 ms ± 154 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) shape: (1000, 100, 100), device: cpu, dtype: torch.float64 33.1 s ± 69.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (1000, 1000), device: cpu, dtype: torch.float64 19.3 s ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) shape: (10, 1000, 1000), device: cpu, dtype: torch.float64 3min 30s ± 1.29 s per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/63880 Reviewed By: soulitzer Differential Revision: D30639435 Pulled By: anjali411 fbshipit-source-id: 127789943ae56e2f1dd03e0fe76ef7b6db86bcf0
Author
Parents
Loading