Move cumprod and cumsum to Aten(CPU) (#33280)
Summary:
This PR moves the CPU implementations of cumprod and cumsum to ATen.
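For context, cumsum and cumprod compute a running sum/product along one dimension; a minimal illustration of the semantics the ported kernels must preserve (not the ATen implementation itself):
```
import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])
# Running sum along dim=1: each element is the sum of all elements up to
# and including its position in that row.
print(torch.cumsum(x, dim=1))   # [[1., 3., 6.], [4., 9., 15.]]
# Running product along dim=1.
print(torch.cumprod(x, dim=1))  # [[1., 2., 6.], [4., 20., 120.]]
```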
Test script:
```
import torch
import time

torch.manual_seed(0)

def _time():
    return time.time()

device = "cpu"
# torch.set_num_threads(1)

# warm up
for n in [10, 300]:
    input = torch.randn(n, n, n, requires_grad=False, device=device)
    input = input * 0.01 + 1
    for dim in range(input.dim()):
        for i in range(100):
            # output = input.cumsum(dim)
            output = input.cumprod(dim)

# timing: average over 1000 iterations per dim
for n in [10, 300]:
    input = torch.randn(n, n, n, requires_grad=False, device=device)
    input = input * 0.01 + 1
    for dim in range(input.dim()):
        fwd_t = 0
        for i in range(1000):
            t1 = _time()
            # output = input.cumsum(dim)
            output = input.cumprod(dim)
            t2 = _time()
            fwd_t = fwd_t + (t2 - t1)
        fwd_avg = fwd_t / 1000 * 1000  # average per-call time in ms
        print("size = (%d, %d, %d); reduce dim=%d; compute time is %.4f(ms)" % (n, n, n, dim, fwd_avg))
```
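As a sanity check on the ported kernels, the outputs can be compared against a slice-by-slice reference along each dimension (a hedged sketch, not part of this PR's test suite):
```
import torch

torch.manual_seed(0)
x = torch.randn(4, 5, 6) * 0.01 + 1

for dim in range(x.dim()):
    # Build slice-by-slice references for cumsum/cumprod along `dim`.
    sum_acc = torch.zeros_like(x.select(dim, 0))
    prod_acc = torch.ones_like(x.select(dim, 0))
    ref_sum = torch.empty_like(x)
    ref_prod = torch.empty_like(x)
    for i in range(x.size(dim)):
        sum_acc = sum_acc + x.select(dim, i)
        prod_acc = prod_acc * x.select(dim, i)
        ref_sum.select(dim, i).copy_(sum_acc)
        ref_prod.select(dim, i).copy_(prod_acc)
    assert torch.allclose(x.cumsum(dim), ref_sum)
    assert torch.allclose(x.cumprod(dim), ref_prod)
print("cumsum/cumprod match the slice-by-slice reference")
```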
Test device: **skx-8180**.
Performance:
```
cumsum:
Before:
size = (10, 10, 10); reduce dim=0; compute time is 0.0098(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0089(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0089(ms)
size = (300, 300, 300); reduce dim=0; compute time is 208.9403(ms)
size = (300, 300, 300); reduce dim=1; compute time is 241.5989(ms)
size = (300, 300, 300); reduce dim=2; compute time is 66.2587(ms)
After:
size = (10, 10, 10); reduce dim=0; compute time is 0.0065(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0063(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0053(ms)
size = (300, 300, 300); reduce dim=0; compute time is 36.0139(ms)
size = (300, 300, 300); reduce dim=1; compute time is 36.0776(ms)
size = (300, 300, 300); reduce dim=2; compute time is 21.0111(ms)
number_threads = 1:
size = (10, 10, 10); reduce dim=0; compute time is 0.0053(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0052(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0051(ms)
size = (300, 300, 300); reduce dim=0; compute time is 81.8831(ms)
size = (300, 300, 300); reduce dim=1; compute time is 88.5687(ms)
size = (300, 300, 300); reduce dim=2; compute time is 54.9922(ms)
cumprod:
Before:
size = (10, 10, 10); reduce dim=0; compute time is 0.0096(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0088(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0088(ms)
size = (300, 300, 300); reduce dim=0; compute time is 221.2601(ms)
size = (300, 300, 300); reduce dim=1; compute time is 249.7894(ms)
size = (300, 300, 300); reduce dim=2; compute time is 71.5182(ms)
number_threads = 1:
size = (10, 10, 10); reduce dim=0; compute time is 0.0100(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0093(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0093(ms)
size = (300, 300, 300); reduce dim=0; compute time is 207.6287(ms)
size = (300, 300, 300); reduce dim=1; compute time is 241.6693(ms)
size = (300, 300, 300); reduce dim=2; compute time is 66.2977(ms)
After:
size = (10, 10, 10); reduce dim=0; compute time is 0.0063(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0062(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0053(ms)
size = (300, 300, 300); reduce dim=0; compute time is 36.4283(ms)
size = (300, 300, 300); reduce dim=1; compute time is 38.1139(ms)
size = (300, 300, 300); reduce dim=2; compute time is 20.9140(ms)
number_threads = 1:
size = (10, 10, 10); reduce dim=0; compute time is 0.0052(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0052(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0050(ms)
size = (300, 300, 300); reduce dim=0; compute time is 82.6926(ms)
size = (300, 300, 300); reduce dim=1; compute time is 90.1265(ms)
size = (300, 300, 300); reduce dim=2; compute time is 55.0196(ms)
```
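The `number_threads = 1` rows come from pinning the intra-op thread pool before timing (uncomment the `torch.set_num_threads(1)` line in the script above); for example:
```
import torch

torch.set_num_threads(1)        # pin intra-op parallelism to one thread
print(torch.get_num_threads())  # 1
```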
Fixes https://github.com/pytorch/pytorch/issues/24668 and https://github.com/pytorch/pytorch/issues/24669.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33280
Differential Revision: D20076997
Pulled By: VitalyFedyunin
fbshipit-source-id: 12225767da8cfdc5e44257462a432bffa04cd469