Faster cumsum and cumprod backwards (#60642)
Summary:
Piggybacking on https://github.com/pytorch/pytorch/pull/58747, we can now implement the backward of `cumsum` and `cumprod` without tricks. This minimises the number of kernels launched on the GPU, so we see a reasonable speed-up there. We should also get better stability for ill-conditioned inputs, as we no longer rely on numerical tricks to compute the result.
Note that the benchmarks time forward + backward, so the speed-up on the backward alone is larger than what is shown. Even more so for `cumsum`, whose backward requires fewer operations than that of `cumprod`.
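For reference, here is a minimal sketch of the closed-form gradients involved, assuming the input has no zeros (as the benchmark script below enforces). The helper names are illustrative only and are not functions from this PR; the actual kernels may differ.

```python
import torch

def cumsum_backward(grad_output, dim):
    # dL/dx_i = sum_{j >= i} grad_j, i.e. a reversed cumulative sum.
    return grad_output.flip(dim).cumsum(dim).flip(dim)

def cumprod_backward_no_zeros(grad_output, x, dim):
    # With y = cumprod(x) and no zeros in x:
    # dL/dx_i = (1 / x_i) * sum_{j >= i} grad_j * y_j,
    # i.e. a reversed cumulative sum of grad * y, divided by x.
    y = x.cumprod(dim)
    return (grad_output * y).flip(dim).cumsum(dim).flip(dim) / x

# Quick sanity check against autograd.
x = torch.rand(4, 5) + 1e-3  # keep the input away from zero
x.requires_grad_(True)
grad = torch.ones_like(x)
auto_grad, = torch.autograd.grad(x.cumprod(1), x, grad_outputs=grad)
print(torch.allclose(auto_grad, cumprod_backward_no_zeros(grad, x, 1)))
```

Both gradients reduce to a flip + cumulative sum, which keeps the number of GPU kernel launches small.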
<details>
<summary>
Test Script
</summary>
```python
from itertools import product

import torch
from torch.utils.benchmark import Compare, Timer


def get_timer(ndims, prod_dim, dim, num_threads, device):
    size = [500] * ndims
    size[dim] = prod_dim
    x = torch.rand(*size, device=device, requires_grad=True)
    # Make sure there are no zeros, as the backward formula
    # being tested assumes the input has no zeros
    with torch.no_grad():
        x.add_(1e-3)
    grad = torch.ones_like(x)
    timer = Timer(
        "torch.autograd.grad([x.cumprod(dim)], [x], grad_outputs=[grad])",
        globals={"x": x, "dim": dim, "grad": grad},
        label=f"Cumprod + Backwards {device}",
        description=f"dim: {dim}",
        sub_label=f"prod_dim: {prod_dim}",
        num_threads=num_threads,
    )
    return timer.blocked_autorange(min_run_time=5)


def get_params():
    ndims = 3
    dims = range(ndims)
    prod_dims = [10, 100, 500]
    for dim, prod_dim, device in product(dims, prod_dims, ("cpu", "cuda")):
        threads = (1, 2, 4) if device == "cpu" else (1,)
        for num_threads in threads:
            yield ndims, prod_dim, dim, num_threads, device


compare = Compare([get_timer(*params) for params in get_params()])
compare.trim_significant_figures()
compare.print()
```
</details>
<details>
<summary>
Benchmark PR
</summary>
```
[------------ Cumprod + Backwards cpu -------------]
| dim: 0 | dim: 1 | dim: 2
1 threads: -----------------------------------------
prod_dim: 10 | 11 | 14 | 12
prod_dim: 100 | 260 | 270 | 260
prod_dim: 500 | 1400 | 1550 | 1360
2 threads: -----------------------------------------
prod_dim: 10 | 6 | 6 | 6
prod_dim: 100 | 170 | 166 | 167
prod_dim: 500 | 902 | 950 | 858
4 threads: -----------------------------------------
prod_dim: 10 | 4 | 3 | 3
prod_dim: 100 | 110 | 108 | 106
prod_dim: 500 | 576 | 590 | 547
Times are in milliseconds (ms).
[------------ Cumprod + Backwards cuda ------------]
| dim: 0 | dim: 1 | dim: 2
1 threads: -----------------------------------------
prod_dim: 10 | 562 | 566 | 1075
prod_dim: 100 | 5388 | 5394 | 6697
prod_dim: 500 | 28170 | 27580 | 30740
Times are in microseconds (us).
```
</details>
<details>
<summary>
Benchmark master
</summary>
```
[------------ Cumprod + Backwards cpu -------------]
| dim: 0 | dim: 1 | dim: 2
1 threads: -----------------------------------------
prod_dim: 10 | 11 | 13 | 12
prod_dim: 100 | 270 | 270 | 256
prod_dim: 500 | 1500 | 1590 | 1300
2 threads: -----------------------------------------
prod_dim: 10 | 6 | 6 | 6
prod_dim: 100 | 170 | 170 | 164
prod_dim: 500 | 911 | 940 | 840
4 threads: -----------------------------------------
prod_dim: 10 | 4 | 4 | 4
prod_dim: 100 | 111 | 109 | 105
prod_dim: 500 | 570 | 590 | 536
Times are in milliseconds (ms).
[------------ Cumprod + Backwards cuda ------------]
| dim: 0 | dim: 1 | dim: 2
1 threads: -----------------------------------------
prod_dim: 10 | 616 | 597 | 1109
prod_dim: 100 | 5976 | 5723 | 7017
prod_dim: 500 | 31110 | 29160 | 32320
Times are in microseconds (us).
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60642
Reviewed By: ngimel
Differential Revision: D29366368
Pulled By: albanD
fbshipit-source-id: b0d692ce030352965c2f152e0f92fbb61fc5ebde