pytorch
8098ae45 - Move rshift to Aten (#31594)

Move rshift to Aten (#31594)

Summary: VitalyFedyunin, this PR moves `rshift` to ATen.

Benchmark script:

```
import timeit
import torch
torch.manual_seed(1)

for n, t in [(10, 100000), (1000, 10000)]:
    print('__rshift__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >> b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                                setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")',
                                number=t))
        for dtype in ('torch.float32', 'torch.float64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >> b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                                setup=f'import torch; a = torch.randn({n}, dtype = {dtype}, device="{device}"); b = torch.randn({n}, dtype = {dtype}, device="{device}")',
                                number=t))

for n, t in [(10, 100000), (1000, 10000)]:
    print('__irshift__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >> b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                                setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")',
                                number=t))
        for dtype in ('torch.float32', 'torch.float64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a >> b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                                setup=f'import torch; a = torch.randn({n}, dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")',
                                number=t))
```

Device: **Tesla P100, skx-8180**; CUDA version: **9.0.176**

Before:

```
__rshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times      0.17183916084468365
device: cpu, dtype: torch.uint8, 100000 times     0.16587729007005692
device: cpu, dtype: torch.int16, 100000 times     0.16659130714833736
device: cpu, dtype: torch.int32, 100000 times     0.17177579551935196
device: cpu, dtype: torch.int64, 100000 times     0.17860156949609518
device: cpu, dtype: torch.float32, 100000 times   0.23938780091702938
device: cpu, dtype: torch.float64, 100000 times   0.22591270506381989
device: cuda, dtype: torch.int8, 100000 times     1.2709560776129365
device: cuda, dtype: torch.uint8, 100000 times    1.2692269310355186
device: cuda, dtype: torch.int16, 100000 times    1.2785452520474792
device: cuda, dtype: torch.int32, 100000 times    1.2733035255223513
device: cuda, dtype: torch.int64, 100000 times    1.2785427365452051
device: cuda, dtype: torch.float32, 100000 times  1.2980637094005942
device: cuda, dtype: torch.float64, 100000 times  1.3062487514689565
__rshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times       0.03122080024331808
device: cpu, dtype: torch.uint8, 10000 times      0.030290847644209862
device: cpu, dtype: torch.int16, 10000 times      0.024531075730919838
device: cpu, dtype: torch.int32, 10000 times      0.024743229150772095
device: cpu, dtype: torch.int64, 10000 times      0.025563121773302555
device: cpu, dtype: torch.float32, 10000 times    0.6707976600155234
device: cpu, dtype: torch.float64, 10000 times    0.5344798369333148
device: cuda, dtype: torch.int8, 10000 times      0.12768010422587395
device: cuda, dtype: torch.uint8, 10000 times     0.12681372743099928
device: cuda, dtype: torch.int16, 10000 times     0.12995595764368773
device: cuda, dtype: torch.int32, 10000 times     0.12989260721951723
device: cuda, dtype: torch.int64, 10000 times     0.12804713658988476
device: cuda, dtype: torch.float32, 10000 times   0.13013121113181114
device: cuda, dtype: torch.float64, 10000 times   0.1406280631199479
__irshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times      0.3805475188419223
device: cpu, dtype: torch.uint8, 100000 times     0.36341007333248854
device: cpu, dtype: torch.int16, 100000 times     0.36908434610813856
device: cpu, dtype: torch.int32, 100000 times     0.3669992135837674
device: cpu, dtype: torch.int64, 100000 times     0.37847711704671383
device: cpu, dtype: torch.float32, 100000 times   0.4311870699748397
device: cpu, dtype: torch.float64, 100000 times   0.44503832422196865
device: cuda, dtype: torch.int8, 100000 times     1.4343859804794192
device: cuda, dtype: torch.uint8, 100000 times    1.4298221375793219
device: cuda, dtype: torch.int16, 100000 times    1.4460898758843541
device: cuda, dtype: torch.int32, 100000 times    1.4518025070428848
device: cuda, dtype: torch.int64, 100000 times    1.4456725595518947
device: cuda, dtype: torch.float32, 100000 times  1.4610810624435544
device: cuda, dtype: torch.float64, 100000 times  1.4736663019284606
__irshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times       0.05944254994392395
device: cpu, dtype: torch.uint8, 10000 times      0.058085592463612556
device: cpu, dtype: torch.int16, 10000 times      0.05094402376562357
device: cpu, dtype: torch.int32, 10000 times      0.050842881202697754
device: cpu, dtype: torch.int64, 10000 times      0.06223891582340002
device: cpu, dtype: torch.float32, 10000 times    0.7006897022947669
device: cpu, dtype: torch.float64, 10000 times    0.5614962242543697
device: cuda, dtype: torch.int8, 10000 times      0.1461706068366766
device: cuda, dtype: torch.uint8, 10000 times     0.14335164614021778
device: cuda, dtype: torch.int16, 10000 times     0.1448021186515689
device: cuda, dtype: torch.int32, 10000 times     0.14513055887073278
device: cuda, dtype: torch.int64, 10000 times     0.1439579650759697
device: cuda, dtype: torch.float32, 10000 times   0.14666561130434275
device: cuda, dtype: torch.float64, 10000 times   0.1540807681158185
```

After:

```
__rshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times      0.16366520430892706
device: cpu, dtype: torch.uint8, 100000 times     0.16091545950621367
device: cpu, dtype: torch.int16, 100000 times     0.1659633992239833
device: cpu, dtype: torch.int32, 100000 times     0.1682385364547372
device: cpu, dtype: torch.int64, 100000 times     0.17289020214229822
device: cpu, dtype: torch.float32, 100000 times   0.24359441827982664
device: cpu, dtype: torch.float64, 100000 times   0.21783945057541132
device: cuda, dtype: torch.int8, 100000 times     1.2517220517620444
device: cuda, dtype: torch.uint8, 100000 times    1.260181212797761
device: cuda, dtype: torch.int16, 100000 times    1.2681935774162412
device: cuda, dtype: torch.int32, 100000 times    1.2764465296640992
device: cuda, dtype: torch.int64, 100000 times    1.294325228780508
device: cuda, dtype: torch.float32, 100000 times  1.3062216322869062
device: cuda, dtype: torch.float64, 100000 times  1.303224254399538
__rshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times       0.027045012451708317
device: cpu, dtype: torch.uint8, 10000 times      0.026978280395269394
device: cpu, dtype: torch.int16, 10000 times      0.025594274513423443
device: cpu, dtype: torch.int32, 10000 times      0.02593063935637474
device: cpu, dtype: torch.int64, 10000 times      0.02668109256774187
device: cpu, dtype: torch.float32, 10000 times    0.09746317192912102
device: cpu, dtype: torch.float64, 10000 times    0.1644029449671507
device: cuda, dtype: torch.int8, 10000 times      0.12530914042145014
device: cuda, dtype: torch.uint8, 10000 times     0.12615622486919165
device: cuda, dtype: torch.int16, 10000 times     0.12741118855774403
device: cuda, dtype: torch.int32, 10000 times     0.1284919548779726
device: cuda, dtype: torch.int64, 10000 times     0.12974756956100464
device: cuda, dtype: torch.float32, 10000 times   0.13044228963553905
device: cuda, dtype: torch.float64, 10000 times   0.13918257877230644
__irshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times      0.19456563983112574
device: cpu, dtype: torch.uint8, 100000 times     0.190769555978477
device: cpu, dtype: torch.int16, 100000 times     0.2002257639542222
device: cpu, dtype: torch.int32, 100000 times     0.20456529594957829
device: cpu, dtype: torch.int64, 100000 times     0.2043834924697876
device: cpu, dtype: torch.float32, 100000 times   0.2832390898838639
device: cpu, dtype: torch.float64, 100000 times   0.2582795573398471
device: cuda, dtype: torch.int8, 100000 times     1.304957083426416
device: cuda, dtype: torch.uint8, 100000 times    1.3216373259201646
device: cuda, dtype: torch.int16, 100000 times    1.3238621400669217
device: cuda, dtype: torch.int32, 100000 times    1.333009460940957
device: cuda, dtype: torch.int64, 100000 times    1.3835567953065038
device: cuda, dtype: torch.float32, 100000 times  1.4483617274090648
device: cuda, dtype: torch.float64, 100000 times  1.4179155295714736
__irshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times       0.03196091763675213
device: cpu, dtype: torch.uint8, 10000 times      0.03048650734126568
device: cpu, dtype: torch.int16, 10000 times      0.03048624936491251
device: cpu, dtype: torch.int32, 10000 times      0.030591044574975967
device: cpu, dtype: torch.int64, 10000 times      0.031246556900441647
device: cpu, dtype: torch.float32, 10000 times    0.10918692220002413
device: cpu, dtype: torch.float64, 10000 times    0.18057993799448013
device: cuda, dtype: torch.int8, 10000 times      0.13614848721772432
device: cuda, dtype: torch.uint8, 10000 times     0.130373639985919
device: cuda, dtype: torch.int16, 10000 times     0.1332557238638401
device: cuda, dtype: torch.int32, 10000 times     0.1331850504502654
device: cuda, dtype: torch.int64, 10000 times     0.1363008264452219
device: cuda, dtype: torch.float32, 10000 times   0.1370363561436534
device: cuda, dtype: torch.float64, 10000 times   0.1442740885540843
```

Fixes https://github.com/pytorch/pytorch/issues/24512, #24516, https://github.com/pytorch/pytorch/issues/24659, https://github.com/pytorch/pytorch/issues/24663

Pull Request resolved: https://github.com/pytorch/pytorch/pull/31594
Differential Revision: D19346542
Pulled By: ezyang
fbshipit-source-id: 37dd00b86898810b850cf4769c3af8aea6d4596b
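For readers unfamiliar with the dunder names in the benchmark: `a >> b` dispatches to `__rshift__` (out-of-place) and `a >>= b` to `__irshift__` (in-place), per Python's data model. A minimal pure-Python sketch of that protocol (the `IntBox` class is a hypothetical illustration, not a PyTorch type):

```python
# Sketch of the operator protocol the benchmark names refer to;
# IntBox is hypothetical and stands in for a tensor-like object.
class IntBox:
    def __init__(self, value):
        self.value = value

    def __rshift__(self, other):
        # a >> b: returns a new object, leaves a unchanged
        return IntBox(self.value >> other.value)

    def __irshift__(self, other):
        # a >>= b: mutates self in place and returns it
        self.value >>= other.value
        return self

a, b = IntBox(16), IntBox(2)
c = a >> b
print(c.value)  # 4
a >>= b
print(a.value)  # 4
```

Moving these operators to ATen means both paths route through ATen kernels instead of the legacy TH implementations.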
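To put the tables above in perspective, the biggest wins are on the CPU float paths at `a.numel() == 1000` (the integer paths were already fast). A quick back-of-the-envelope calculation from the `__rshift__` timings reported above:

```python
# Speedup of the After numbers over the Before numbers, taken from the
# CPU float rows for a.numel() == 1000 in the tables above.
before = {"torch.float32": 0.6707976600155234, "torch.float64": 0.5344798369333148}
after = {"torch.float32": 0.09746317192912102, "torch.float64": 0.1644029449671507}
for dtype in before:
    print(f"{dtype}: {before[dtype] / after[dtype]:.1f}x faster")
# torch.float32: 6.9x faster
# torch.float64: 3.3x faster
```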