Move rshift to ATen (#31594)
Summary:
VitalyFedyunin, this PR moves the `__rshift__` and `__irshift__` operators to ATen.
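For context, the operators touched by this port can be exercised directly from Python; a minimal example (not part of the original PR):
```
import torch

a = torch.tensor([64, 17, 3], dtype=torch.int32)
b = torch.tensor([2, 1, 0], dtype=torch.int32)

print(a >> b)   # Tensor.__rshift__ (tensor-tensor): tensor([16, 8, 3], dtype=torch.int32)
print(a >> 2)   # shift by a Python scalar: tensor([16, 4, 0], dtype=torch.int32)
a >>= b         # Tensor.__irshift__: shifts a in place
print(a)        # tensor([16, 8, 3], dtype=torch.int32)
```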
Benchmark script:
```
import timeit
import torch

torch.manual_seed(1)

# Out-of-place right shift: a >> b (tensor-tensor)
for n, t in [(10, 100000), (1000, 10000)]:
    print('__rshift__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >> b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")', number=t))
        for dtype in ('torch.float32', 'torch.float64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >> b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randn({n}, dtype = {dtype}, device="{device}"); b = torch.randn({n}, dtype = {dtype}, device="{device}")', number=t))

# In-place right shift: a >>= b (tensor shifted by a scalar tensor)
for n, t in [(10, 100000), (1000, 10000)]:
    print('__irshift__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >>= b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t))
        for dtype in ('torch.float32', 'torch.float64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >>= b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randn({n}, dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t))
```
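As a quick sanity check (not part of the original PR), one can also verify that the ported kernels still match Python's elementwise `>>` on integer inputs; a minimal sketch:
```
import torch

a = torch.randint(0, 1024, (1000,), dtype=torch.int64)
b = torch.randint(0, 8, (1000,), dtype=torch.int64)

# Reference result computed with Python's built-in right shift.
expected = torch.tensor([x >> s for x, s in zip(a.tolist(), b.tolist())], dtype=torch.int64)

assert torch.equal(a >> b, expected)    # out-of-place __rshift__

a_inplace = a.clone()
a_inplace >>= b                         # in-place __irshift__
assert torch.equal(a_inplace, expected)
```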
Device: **Tesla P100, skx-8180**
CUDA version: **9.0.176**
Before:
```
__rshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.17183916084468365
device: cpu, dtype: torch.uint8, 100000 times 0.16587729007005692
device: cpu, dtype: torch.int16, 100000 times 0.16659130714833736
device: cpu, dtype: torch.int32, 100000 times 0.17177579551935196
device: cpu, dtype: torch.int64, 100000 times 0.17860156949609518
device: cpu, dtype: torch.float32, 100000 times 0.23938780091702938
device: cpu, dtype: torch.float64, 100000 times 0.22591270506381989
device: cuda, dtype: torch.int8, 100000 times 1.2709560776129365
device: cuda, dtype: torch.uint8, 100000 times 1.2692269310355186
device: cuda, dtype: torch.int16, 100000 times 1.2785452520474792
device: cuda, dtype: torch.int32, 100000 times 1.2733035255223513
device: cuda, dtype: torch.int64, 100000 times 1.2785427365452051
device: cuda, dtype: torch.float32, 100000 times 1.2980637094005942
device: cuda, dtype: torch.float64, 100000 times 1.3062487514689565
__rshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.03122080024331808
device: cpu, dtype: torch.uint8, 10000 times 0.030290847644209862
device: cpu, dtype: torch.int16, 10000 times 0.024531075730919838
device: cpu, dtype: torch.int32, 10000 times 0.024743229150772095
device: cpu, dtype: torch.int64, 10000 times 0.025563121773302555
device: cpu, dtype: torch.float32, 10000 times 0.6707976600155234
device: cpu, dtype: torch.float64, 10000 times 0.5344798369333148
device: cuda, dtype: torch.int8, 10000 times 0.12768010422587395
device: cuda, dtype: torch.uint8, 10000 times 0.12681372743099928
device: cuda, dtype: torch.int16, 10000 times 0.12995595764368773
device: cuda, dtype: torch.int32, 10000 times 0.12989260721951723
device: cuda, dtype: torch.int64, 10000 times 0.12804713658988476
device: cuda, dtype: torch.float32, 10000 times 0.13013121113181114
device: cuda, dtype: torch.float64, 10000 times 0.1406280631199479
__irshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.3805475188419223
device: cpu, dtype: torch.uint8, 100000 times 0.36341007333248854
device: cpu, dtype: torch.int16, 100000 times 0.36908434610813856
device: cpu, dtype: torch.int32, 100000 times 0.3669992135837674
device: cpu, dtype: torch.int64, 100000 times 0.37847711704671383
device: cpu, dtype: torch.float32, 100000 times 0.4311870699748397
device: cpu, dtype: torch.float64, 100000 times 0.44503832422196865
device: cuda, dtype: torch.int8, 100000 times 1.4343859804794192
device: cuda, dtype: torch.uint8, 100000 times 1.4298221375793219
device: cuda, dtype: torch.int16, 100000 times 1.4460898758843541
device: cuda, dtype: torch.int32, 100000 times 1.4518025070428848
device: cuda, dtype: torch.int64, 100000 times 1.4456725595518947
device: cuda, dtype: torch.float32, 100000 times 1.4610810624435544
device: cuda, dtype: torch.float64, 100000 times 1.4736663019284606
__irshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.05944254994392395
device: cpu, dtype: torch.uint8, 10000 times 0.058085592463612556
device: cpu, dtype: torch.int16, 10000 times 0.05094402376562357
device: cpu, dtype: torch.int32, 10000 times 0.050842881202697754
device: cpu, dtype: torch.int64, 10000 times 0.06223891582340002
device: cpu, dtype: torch.float32, 10000 times 0.7006897022947669
device: cpu, dtype: torch.float64, 10000 times 0.5614962242543697
device: cuda, dtype: torch.int8, 10000 times 0.1461706068366766
device: cuda, dtype: torch.uint8, 10000 times 0.14335164614021778
device: cuda, dtype: torch.int16, 10000 times 0.1448021186515689
device: cuda, dtype: torch.int32, 10000 times 0.14513055887073278
device: cuda, dtype: torch.int64, 10000 times 0.1439579650759697
device: cuda, dtype: torch.float32, 10000 times 0.14666561130434275
device: cuda, dtype: torch.float64, 10000 times 0.1540807681158185
```
After:
```
__rshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.16366520430892706
device: cpu, dtype: torch.uint8, 100000 times 0.16091545950621367
device: cpu, dtype: torch.int16, 100000 times 0.1659633992239833
device: cpu, dtype: torch.int32, 100000 times 0.1682385364547372
device: cpu, dtype: torch.int64, 100000 times 0.17289020214229822
device: cpu, dtype: torch.float32, 100000 times 0.24359441827982664
device: cpu, dtype: torch.float64, 100000 times 0.21783945057541132
device: cuda, dtype: torch.int8, 100000 times 1.2517220517620444
device: cuda, dtype: torch.uint8, 100000 times 1.260181212797761
device: cuda, dtype: torch.int16, 100000 times 1.2681935774162412
device: cuda, dtype: torch.int32, 100000 times 1.2764465296640992
device: cuda, dtype: torch.int64, 100000 times 1.294325228780508
device: cuda, dtype: torch.float32, 100000 times 1.3062216322869062
device: cuda, dtype: torch.float64, 100000 times 1.303224254399538
__rshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.027045012451708317
device: cpu, dtype: torch.uint8, 10000 times 0.026978280395269394
device: cpu, dtype: torch.int16, 10000 times 0.025594274513423443
device: cpu, dtype: torch.int32, 10000 times 0.02593063935637474
device: cpu, dtype: torch.int64, 10000 times 0.02668109256774187
device: cpu, dtype: torch.float32, 10000 times 0.09746317192912102
device: cpu, dtype: torch.float64, 10000 times 0.1644029449671507
device: cuda, dtype: torch.int8, 10000 times 0.12530914042145014
device: cuda, dtype: torch.uint8, 10000 times 0.12615622486919165
device: cuda, dtype: torch.int16, 10000 times 0.12741118855774403
device: cuda, dtype: torch.int32, 10000 times 0.1284919548779726
device: cuda, dtype: torch.int64, 10000 times 0.12974756956100464
device: cuda, dtype: torch.float32, 10000 times 0.13044228963553905
device: cuda, dtype: torch.float64, 10000 times 0.13918257877230644
__irshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.19456563983112574
device: cpu, dtype: torch.uint8, 100000 times 0.190769555978477
device: cpu, dtype: torch.int16, 100000 times 0.2002257639542222
device: cpu, dtype: torch.int32, 100000 times 0.20456529594957829
device: cpu, dtype: torch.int64, 100000 times 0.2043834924697876
device: cpu, dtype: torch.float32, 100000 times 0.2832390898838639
device: cpu, dtype: torch.float64, 100000 times 0.2582795573398471
device: cuda, dtype: torch.int8, 100000 times 1.304957083426416
device: cuda, dtype: torch.uint8, 100000 times 1.3216373259201646
device: cuda, dtype: torch.int16, 100000 times 1.3238621400669217
device: cuda, dtype: torch.int32, 100000 times 1.333009460940957
device: cuda, dtype: torch.int64, 100000 times 1.3835567953065038
device: cuda, dtype: torch.float32, 100000 times 1.4483617274090648
device: cuda, dtype: torch.float64, 100000 times 1.4179155295714736
__irshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.03196091763675213
device: cpu, dtype: torch.uint8, 10000 times 0.03048650734126568
device: cpu, dtype: torch.int16, 10000 times 0.03048624936491251
device: cpu, dtype: torch.int32, 10000 times 0.030591044574975967
device: cpu, dtype: torch.int64, 10000 times 0.031246556900441647
device: cpu, dtype: torch.float32, 10000 times 0.10918692220002413
device: cpu, dtype: torch.float64, 10000 times 0.18057993799448013
device: cuda, dtype: torch.int8, 10000 times 0.13614848721772432
device: cuda, dtype: torch.uint8, 10000 times 0.130373639985919
device: cuda, dtype: torch.int16, 10000 times 0.1332557238638401
device: cuda, dtype: torch.int32, 10000 times 0.1331850504502654
device: cuda, dtype: torch.int64, 10000 times 0.1363008264452219
device: cuda, dtype: torch.float32, 10000 times 0.1370363561436534
device: cuda, dtype: torch.float64, 10000 times 0.1442740885540843
```
Fixes https://github.com/pytorch/pytorch/issues/24512, #24516, https://github.com/pytorch/pytorch/issues/24659, https://github.com/pytorch/pytorch/issues/24663
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31594
Differential Revision: D19346542
Pulled By: ezyang
fbshipit-source-id: 37dd00b86898810b850cf4769c3af8aea6d4596b