pytorch @ dfc8247d - Faster cumsum and cumprod backwards (#60642)

Faster cumsum and cumprod backwards (#60642)

Summary:
Piggybacking on https://github.com/pytorch/pytorch/pull/58747, we can now implement the backwards of `cumsum` and `cumprod` without tricks. This minimises the number of kernels launched on the GPU, so we see a reasonable speed-up there. We should also get better stability for ill-conditioned inputs, as we no longer perform any numerical tricks to get the result.

Note that the benchmarks below test forward + backward, so the true speed-up on the backward alone should be even larger. Even more so for `cumsum`, whose backward requires fewer operations than that of `cumprod`.
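For context, here is a minimal sketch (not the code from this patch) of the closed-form identities the summary refers to; the `cumprod` one is only valid when the input has no zeros along `dim`:

```python
import torch

def cumsum_backward(grad, dim):
    # x_i contributes to every output j >= i, so the gradient is a
    # reversed cumulative sum of the output gradient
    return grad.flip(dim).cumsum(dim).flip(dim)

def cumprod_backward(grad, x, dim):
    # Only valid when x has no zeros along dim: with y = cumprod(x),
    # dy_j/dx_i = y_j / x_i for j >= i, so the sum over j collapses
    # into a reversed cumsum followed by a pointwise division
    y = x.cumprod(dim)
    return (grad * y).flip(dim).cumsum(dim).flip(dim) / x

# Quick check against autograd (double precision, input kept away from zero)
x = (torch.rand(4, 5, dtype=torch.double) + 1e-3).requires_grad_()
grad = torch.ones_like(x)
gs, = torch.autograd.grad(x.cumsum(1), x, grad_outputs=grad)
gp, = torch.autograd.grad(x.cumprod(1), x, grad_outputs=grad)
assert torch.allclose(gs, cumsum_backward(grad, 1))
assert torch.allclose(gp, cumprod_backward(grad, x.detach(), 1))
```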
<details>
<summary> Test Script </summary>

```python
from itertools import product

import torch
from torch.utils.benchmark import Compare, Timer


def get_timer(ndims, prod_dim, dim, num_threads, device):
    size = [500] * ndims
    size[dim] = prod_dim
    x = torch.rand(*size, device=device, requires_grad=True)
    # Make sure there are no zeros, as the formula for the backward
    # that we are testing assumes the input has no zeros
    with torch.no_grad():
        x.add_(1e-3)
    grad = torch.ones_like(x)
    timer = Timer(
        "torch.autograd.grad([x.cumprod(dim)], [x], grad_outputs=[grad])",
        globals={"x": x, "dim": dim, "grad": grad},
        label=f"Cumprod + Backwards {device}",
        description=f"dim: {dim}",
        sub_label=f"prod_dim: {prod_dim}",
        num_threads=num_threads,
    )
    return timer.blocked_autorange(min_run_time=5)


def get_params():
    ndims = 3
    dims = range(ndims)
    prod_dims = [10, 100, 500]
    for dim, prod_dim, device in product(dims, prod_dims, ("cpu", "cuda")):
        threads = (1, 2, 4) if device == "cpu" else (1,)
        for num_threads in threads:
            yield ndims, prod_dim, dim, num_threads, device


compare = Compare([get_timer(*params) for params in get_params()])
compare.trim_significant_figures()
compare.print()
```

</details>

<details>
<summary> Benchmark PR </summary>

```
[------------ Cumprod + Backwards cpu -------------]
                     |  dim: 0  |  dim: 1  |  dim: 2
1 threads: -----------------------------------------
      prod_dim: 10   |     11   |     14   |     12
      prod_dim: 100  |    260   |    270   |    260
      prod_dim: 500  |   1400   |   1550   |   1360
2 threads: -----------------------------------------
      prod_dim: 10   |      6   |      6   |      6
      prod_dim: 100  |    170   |    166   |    167
      prod_dim: 500  |    902   |    950   |    858
4 threads: -----------------------------------------
      prod_dim: 10   |      4   |      3   |      3
      prod_dim: 100  |    110   |    108   |    106
      prod_dim: 500  |    576   |    590   |    547

Times are in milliseconds (ms).

[------------ Cumprod + Backwards cuda ------------]
                     |  dim: 0  |  dim: 1  |  dim: 2
1 threads: -----------------------------------------
      prod_dim: 10   |    562   |    566   |   1075
      prod_dim: 100  |   5388   |   5394   |   6697
      prod_dim: 500  |  28170   |  27580   |  30740

Times are in microseconds (us).
```

</details>

<details>
<summary> Benchmark master </summary>

```
[------------ Cumprod + Backwards cpu -------------]
                     |  dim: 0  |  dim: 1  |  dim: 2
1 threads: -----------------------------------------
      prod_dim: 10   |     11   |     13   |     12
      prod_dim: 100  |    270   |    270   |    256
      prod_dim: 500  |   1500   |   1590   |   1300
2 threads: -----------------------------------------
      prod_dim: 10   |      6   |      6   |      6
      prod_dim: 100  |    170   |    170   |    164
      prod_dim: 500  |    911   |    940   |    840
4 threads: -----------------------------------------
      prod_dim: 10   |      4   |      4   |      4
      prod_dim: 100  |    111   |    109   |    105
      prod_dim: 500  |    570   |    590   |    536

Times are in milliseconds (ms).

[------------ Cumprod + Backwards cuda ------------]
                     |  dim: 0  |  dim: 1  |  dim: 2
1 threads: -----------------------------------------
      prod_dim: 10   |    616   |    597   |   1109
      prod_dim: 100  |   5976   |   5723   |   7017
      prod_dim: 500  |  31110   |  29160   |  32320

Times are in microseconds (us).
```

</details>
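Since the numbers above include the forward pass, a hypothetical variant of `get_timer` could isolate the backward by building the graph once and retaining it across runs; a sketch, assuming a CUDA device and one of the sizes from the script:

```python
import torch
from torch.utils.benchmark import Timer

x = torch.rand(500, 100, 500, device="cuda", requires_grad=True)
with torch.no_grad():
    x.add_(1e-3)  # keep the input away from zero, as in the script above
y = x.cumprod(1)  # forward pass runs once, outside the timed statement
grad = torch.ones_like(x)

timer = Timer(
    # retain_graph=True re-uses the same graph on every iteration,
    # so only the backward kernels are measured
    "torch.autograd.grad([y], [x], grad_outputs=[grad], retain_graph=True)",
    globals={"y": y, "x": x, "grad": grad},
    label="Cumprod backward only (cuda)",
)
print(timer.blocked_autorange(min_run_time=5))
```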
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60642

Reviewed By: ngimel

Differential Revision: D29366368

Pulled By: albanD

fbshipit-source-id: b0d692ce030352965c2f152e0f92fbb61fc5ebde