Faster backwards for cumsum and cumprod (#53711)
Summary:
Provides a faster formula for `cumprod` in the case when the input has zeros. This formula is non-differentiable, so we leave the previous formula for the cases when `at::GradMode::is_enabled()`.
This new formula gives up to x10 and x30 speed-ups in CPU and GPU (see the benchmarks below).
The `cumsum` backward formula was rewritten so that no copies are necessary. We also removed a double negation in its formula. This gives a significant speed-up in CPU, while being almost as efficient as the formula with copies in GPU. We can see this speed-up when comparing the "No zeros" part of the benchmark.
Benchmarks:
nb. It is worth noting that the script tests the forward and the backward for `cumprod`, so the speed-ups should be even larger than those announced here.
<details>
<summary>Script</summary>
```python
from IPython import get_ipython
import torch
from itertools import product
torch.manual_seed(13)
torch.set_num_threads(1)
ipython = get_ipython()
cpu = torch.device('cpu')
cuda = torch.device('cuda')
def run_test(ndims, size, size_prod, zeros, device):
print(f"ndims: {ndims}, tensor_size: {size}, size_prod: {size_prod}, zeros: {zeros}, device: {device}")
for dim in range(ndims):
sizes = ndims * [size]
sizes[dim] = size_prod
tensor = torch.rand(*sizes, device=device)
with torch.no_grad():
if zeros:
# Set 0.1 of them to zero
p_drop = 0.1
mask = torch.full_like(tensor, 1.0 - p_drop)
tensor = tensor * torch.bernoulli(mask)
else:
tensor = tensor + 1e-3
tensor.requires_grad_()
grad = torch.ones_like(tensor)
# We test both forward + backward, meaning that the speed-up is actually greater than reported
# That being said, this is more realistic than doing `retain_graph=True`
command = "torch.autograd.grad([tensor.cumprod(dim)], [tensor], grad_outputs=[grad])"
if device == cuda:
command += "; torch.cuda.synchronize()"
ipython.magic(f"timeit {command}")
print()
for device, zeros in product([cuda, cpu], [True, False]):
run_test(3, 300, 10, zeros, device)
run_test(3, 300, 100, zeros, device)
if device == cuda:
run_test(3, 300, 300, zeros, device)
```
</details>
<details>
<summary>CPU This PR (Some regression small tensors, x4 speed-up large tensors)</summary>
```
Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: True, device: cpu
28.2 ms ± 12.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
29.8 ms ± 78.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
24.5 ms ± 29.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: True, device: cpu
414 ms ± 3.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
428 ms ± 4.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
382 ms ± 3.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
No Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: False, device: cpu
3.11 ms ± 9.72 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.83 ms ± 3.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.08 ms ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: False, device: cpu
92.2 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
101 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
87 ms ± 170 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
</details>
<details>
<summary>CUDA This PR (7-30x speed-up)</summary>
```
Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: True, device: cuda
1.46 ms ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.48 ms ± 3.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.93 ms ± 8.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: True, device: cuda
10.5 ms ± 914 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.6 ms ± 509 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
11.7 ms ± 864 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
ndims: 3, tensor_size: 300, size_prod: 300, zeros: True, device: cuda
30.3 ms ± 5.16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
30.6 ms ± 6.44 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
32.2 ms ± 2.34 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
No Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: False, device: cuda
248 µs ± 335 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
252 µs ± 186 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
438 µs ± 254 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: False, device: cuda
2.1 ms ± 193 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.16 ms ± 380 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.59 ms ± 398 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
ndims: 3, tensor_size: 300, size_prod: 300, zeros: False, device: cuda
6.3 ms ± 857 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.39 ms ± 288 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.15 ms ± 233 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
</details>
<details>
<summary>CPU master</summary>
```
Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: True, device: cpu
8.27 ms ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.8 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
28.2 ms ± 74.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: True, device: cpu
1.53 s ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.95 s ± 4.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.86 s ± 3.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
No Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: False, device: cpu
3.42 ms ± 20 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.25 ms ± 3.65 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.34 ms ± 3.04 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: False, device: cpu
104 ms ± 148 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
117 ms ± 99.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
94.8 ms ± 125 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
</details>
<details>
<summary>CUDA master</summary>
```
Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: True, device: cuda
912 µs ± 431 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.05 ms ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.74 ms ± 381 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: True, device: cuda
71.3 ms ± 7.91 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
85.4 ms ± 9.82 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
119 ms ± 6.21 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
ndims: 3, tensor_size: 300, size_prod: 300, zeros: True, device: cuda
646 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
776 ms ± 81.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
917 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
No Zeros:
ndims: 3, tensor_size: 300, size_prod: 10, zeros: False, device: cuda
301 µs ± 893 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
308 µs ± 236 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
592 µs ± 140 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
ndims: 3, tensor_size: 300, size_prod: 100, zeros: False, device: cuda
2.61 ms ± 375 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.68 ms ± 524 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.38 ms ± 736 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
ndims: 3, tensor_size: 300, size_prod: 300, zeros: False, device: cuda
7.89 ms ± 848 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
8.03 ms ± 517 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.24 ms ± 405 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
</details>
cc nikitaved
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53711
Reviewed By: jbschlosser
Differential Revision: D27059662
Pulled By: anjali411
fbshipit-source-id: be610d5590c0199b4412dff66fac47666faaff9d