Restore thrust path for 1d tensors cumulative ops (#39180)
Summary:
Restores the Thrust path for computing prefix sums of tensors with a single non-degenerate dimension. Benchmark script (run on a P100); results before this change:
```
import time

import torch

l = 4000  # number of elements; doubled on each outer iteration
t = 1000  # timed iterations per configuration
for _ in range(6):
    for dtype in (torch.half, torch.float, torch.double):
        a = torch.randn(l, device="cuda", dtype=dtype)
        print(f'torch.cumsum(a) a.numel() == {l} for {t} times {dtype}')
        # dry run to exclude one-time setup cost from the timing
        torch.cumsum(a, 0)
        torch.cuda.synchronize()
        # timed iterations
        start = time.time()
        for _ in range(t):
            torch.cumsum(a, 0)
        # final synchronize so all queued kernels are counted
        torch.cuda.synchronize()
        end = time.time()
        elapsed = end - start
        # effective bandwidth in GB/s: read + write of l elements, t times
        bw = t * l * 2 * a.element_size() * 1e-9 / elapsed
        print(f'Time {elapsed} bandwidth {bw}')
    l *= 2
```
```
torch.cumsum(a) a.numel() == 4000 for 1000 times torch.float16
Time 0.29149866104125977 bandwidth 0.05488875984145705
torch.cumsum(a) a.numel() == 4000 for 1000 times torch.float32
Time 0.24511313438415527 bandwidth 0.130551959528402
torch.cumsum(a) a.numel() == 4000 for 1000 times torch.float64
Time 0.25238871574401855 bandwidth 0.25357710550304885
torch.cumsum(a) a.numel() == 8000 for 1000 times torch.float16
Time 0.5812790393829346 bandwidth 0.05505101307965633
torch.cumsum(a) a.numel() == 8000 for 1000 times torch.float32
Time 0.4885847568511963 bandwidth 0.13099057861007293
torch.cumsum(a) a.numel() == 8000 for 1000 times torch.float64
Time 0.5031211376190186 bandwidth 0.2544118909528429
torch.cumsum(a) a.numel() == 16000 for 1000 times torch.float16
Time 1.1607651710510254 bandwidth 0.05513604439220951
torch.cumsum(a) a.numel() == 16000 for 1000 times torch.float32
Time 0.9755356311798096 bandwidth 0.13120996907637011
torch.cumsum(a) a.numel() == 16000 for 1000 times torch.float64
Time 1.0045702457427979 bandwidth 0.25483533987283175
torch.cumsum(a) a.numel() == 32000 for 1000 times torch.float16
Time 2.3198938369750977 bandwidth 0.055174938594129294
torch.cumsum(a) a.numel() == 32000 for 1000 times torch.float32
Time 1.949366569519043 bandwidth 0.13132471029456586
torch.cumsum(a) a.numel() == 32000 for 1000 times torch.float64
Time 2.00749135017395 bandwidth 0.2550446854755488
torch.cumsum(a) a.numel() == 64000 for 1000 times torch.float16
Time 4.63812518119812 bandwidth 0.055194715536735495
torch.cumsum(a) a.numel() == 64000 for 1000 times torch.float32
Time 3.897014856338501 bandwidth 0.13138261435345344
torch.cumsum(a) a.numel() == 64000 for 1000 times torch.float64
Time 4.013219356536865 bandwidth 0.2551567479938705
torch.cumsum(a) a.numel() == 128000 for 1000 times torch.float16
Time 9.274584770202637 bandwidth 0.05520462777427539
torch.cumsum(a) a.numel() == 128000 for 1000 times torch.float32
Time 7.792156934738159 bandwidth 0.1314141910354645
torch.cumsum(a) a.numel() == 128000 for 1000 times torch.float64
Time 8.02474856376648 bandwidth 0.2552104883693396
```
after:
```
torch.cumsum(a) a.numel() == 4000 for 1000 times torch.float16
Time 0.033731937408447266 bandwidth 0.47432792864109924
torch.cumsum(a) a.numel() == 4000 for 1000 times torch.float32
Time 0.031197071075439453 bandwidth 1.025737317539167
torch.cumsum(a) a.numel() == 4000 for 1000 times torch.float64
Time 0.03245425224304199 bandwidth 1.972006611667389
torch.cumsum(a) a.numel() == 8000 for 1000 times torch.float16
Time 0.034340858459472656 bandwidth 0.931834596906329
torch.cumsum(a) a.numel() == 8000 for 1000 times torch.float32
Time 0.031183481216430664 bandwidth 2.0523686741645197
torch.cumsum(a) a.numel() == 8000 for 1000 times torch.float64
Time 0.031975507736206055 bandwidth 4.003063878015136
torch.cumsum(a) a.numel() == 16000 for 1000 times torch.float16
Time 0.032624006271362305 bandwidth 1.9617455767895642
torch.cumsum(a) a.numel() == 16000 for 1000 times torch.float32
Time 0.03129267692565918 bandwidth 4.0904138787514
torch.cumsum(a) a.numel() == 16000 for 1000 times torch.float64
Time 0.03260397911071777 bandwidth 7.851802356107085
torch.cumsum(a) a.numel() == 32000 for 1000 times torch.float16
Time 0.032918691635131836 bandwidth 3.888368390176069
torch.cumsum(a) a.numel() == 32000 for 1000 times torch.float32
Time 0.030851364135742188 bandwidth 8.29785026275116
torch.cumsum(a) a.numel() == 32000 for 1000 times torch.float64
Time 0.037447452545166016 bandwidth 13.6724921243299
torch.cumsum(a) a.numel() == 64000 for 1000 times torch.float16
Time 0.03391098976135254 bandwidth 7.549175114073387
torch.cumsum(a) a.numel() == 64000 for 1000 times torch.float32
Time 0.03214144706726074 bandwidth 15.929587704267457
torch.cumsum(a) a.numel() == 64000 for 1000 times torch.float64
Time 0.034329891204833984 bandwidth 29.828233182859922
torch.cumsum(a) a.numel() == 128000 for 1000 times torch.float16
Time 0.03589606285095215 bandwidth 14.263402705915954
torch.cumsum(a) a.numel() == 128000 for 1000 times torch.float32
Time 0.033178091049194336 bandwidth 30.863740728231736
torch.cumsum(a) a.numel() == 128000 for 1000 times torch.float64
Time 0.03487515449523926 bandwidth 58.72375419238841
```
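For context, `torch.cumsum` computes an inclusive prefix scan. The semantics (independent of which GPU code path, Thrust or otherwise, produces the result) match Python's `itertools.accumulate`; a minimal pure-Python sketch:

```python
from itertools import accumulate

# Inclusive prefix scan: out[i] = x[0] + x[1] + ... + x[i].
# This is what torch.cumsum(a, 0) computes for a 1-D tensor; this PR
# only changes which GPU implementation performs the scan.
x = [1.0, 2.0, 3.0, 4.0]
prefix = list(accumulate(x))
print(prefix)  # [1.0, 3.0, 6.0, 10.0]
```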
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39180
Differential Revision: D21824498
Pulled By: ngimel
fbshipit-source-id: b50fadde598e9ce2871201cd6bb22fa6ac0d482e