pytorch
9ba6a768 - Add op bitwise_or (#31559)

Commit
4 years ago
Add op bitwise_or (#31559) Summary: ezyang , this PR add bitwise_or operator as https://github.com/pytorch/pytorch/pull/31104 . Benchmark script : ``` import timeit import torch torch.manual_seed(1) for n, t in [(10, 100000),(1000, 10000)]: print('__or__ (a.numel() == {}) for {} times'.format(n, t)) for device in ('cpu', 'cuda'): for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'): print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t') print(timeit.timeit(f'a | b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")', number=t)) for n, t in [(10, 100000),(1000, 10000)]: print('__ior__ (a.numel() == {}) for {} times'.format(n, t)) for device in ('cpu', 'cuda'): for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'): print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t') print(timeit.timeit(f'a | b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t)) ``` Device: **Tesla P100, skx-8180** Cuda verison: **9.0.176** Before: ``` __or__ (a.numel() == 10) for 100000 times device: cpu, dtype: torch.int8, 100000 times 0.17616272252053022 device: cpu, dtype: torch.uint8, 100000 times 0.17148233391344547 device: cpu, dtype: torch.int16, 100000 times 0.17616403382271528 device: cpu, dtype: torch.int32, 100000 times 0.17717823758721352 device: cpu, dtype: torch.int64, 100000 times 0.1801931718364358 device: cuda, dtype: torch.int8, 100000 times 1.270583058707416 device: cuda, dtype: torch.uint8, 100000 times 1.2636413089931011 device: cuda, dtype: torch.int16, 100000 times 1.2839747751131654 device: cuda, dtype: torch.int32, 100000 times 1.2548385225236416 device: cuda, dtype: torch.int64, 100000 times 1.2650810535997152 __or__ (a.numel() == 1000) for 10000 times device: cpu, dtype: torch.int8, 10000 times 0.031136621721088886 device: cpu, dtype: torch.uint8, 10000 times 0.030786747112870216 device: cpu, dtype: torch.int16, 10000 times 0.02391665056347847 device: cpu, dtype: torch.int32, 10000 times 0.024147341027855873 device: cpu, dtype: torch.int64, 10000 times 0.024414129555225372 device: cuda, dtype: torch.int8, 10000 times 0.12741921469569206 device: cuda, dtype: torch.uint8, 10000 times 0.1249831635504961 device: cuda, dtype: torch.int16, 10000 times 0.1283819805830717 device: cuda, dtype: torch.int32, 10000 times 0.12591975275427103 device: cuda, dtype: torch.int64, 10000 times 0.12655890546739101 __ior__ (a.numel() == 10) for 100000 times device: cpu, dtype: torch.int8, 100000 times 0.3908365070819855 device: cpu, dtype: torch.uint8, 100000 times 0.38267823681235313 device: cpu, dtype: torch.int16, 100000 times 0.38239253498613834 device: cpu, dtype: torch.int32, 100000 times 0.3817988149821758 device: cpu, dtype: torch.int64, 100000 times 0.3901665909215808 device: cuda, dtype: torch.int8, 100000 times 1.4211318120360374 device: cuda, dtype: torch.uint8, 100000 times 1.4215159295126796 device: cuda, dtype: torch.int16, 100000 times 1.4307750314474106 device: cuda, dtype: torch.int32, 100000 times 1.4123614141717553 device: cuda, dtype: torch.int64, 100000 times 1.4480243818834424 __ior__ (a.numel() == 1000) for 10000 times device: cpu, dtype: torch.int8, 10000 times 0.06468924414366484 device: cpu, dtype: torch.uint8, 10000 times 0.06442475505173206 device: cpu, dtype: torch.int16, 10000 times 0.05267547257244587 device: cpu, dtype: torch.int32, 10000 times 0.05286940559744835 device: cpu, dtype: torch.int64, 10000 times 0.06211103219538927 device: cuda, dtype: torch.int8, 10000 times 0.15332304500043392 device: cuda, dtype: torch.uint8, 10000 times 0.15353196952492 device: cuda, dtype: torch.int16, 10000 times 0.15300503931939602 device: cuda, dtype: torch.int32, 10000 times 0.15274472255259752 device: cuda, dtype: torch.int64, 10000 times 0.1512152962386608 ``` After: ``` __or__ (a.numel() == 10) for 100000 times device: cpu, dtype: torch.int8, 100000 times 0.2465507509186864 device: cpu, dtype: torch.uint8, 100000 times 0.2472386620938778 device: cpu, dtype: torch.int16, 100000 times 0.2469814233481884 device: cpu, dtype: torch.int32, 100000 times 0.2535214088857174 device: cpu, dtype: torch.int64, 100000 times 0.24855613708496094 device: cuda, dtype: torch.int8, 100000 times 1.4351346511393785 device: cuda, dtype: torch.uint8, 100000 times 1.4434308474883437 device: cuda, dtype: torch.int16, 100000 times 1.4520929995924234 device: cuda, dtype: torch.int32, 100000 times 1.4456610176712275 device: cuda, dtype: torch.int64, 100000 times 1.4580101007595658 __or__ (a.numel() == 1000) for 10000 times device: cpu, dtype: torch.int8, 10000 times 0.029985425993800163 device: cpu, dtype: torch.uint8, 10000 times 0.03024935908615589 device: cpu, dtype: torch.int16, 10000 times 0.026356655173003674 device: cpu, dtype: torch.int32, 10000 times 0.027377349324524403 device: cpu, dtype: torch.int64, 10000 times 0.029163731262087822 device: cuda, dtype: torch.int8, 10000 times 0.14540370367467403 device: cuda, dtype: torch.uint8, 10000 times 0.1456305105239153 device: cuda, dtype: torch.int16, 10000 times 0.1450125053524971 device: cuda, dtype: torch.int32, 10000 times 0.1472016740590334 device: cuda, dtype: torch.int64, 10000 times 0.14709716010838747 __ior__ (a.numel() == 10) for 100000 times device: cpu, dtype: torch.int8, 100000 times 0.27195510920137167 device: cpu, dtype: torch.uint8, 100000 times 0.2692424338310957 device: cpu, dtype: torch.int16, 100000 times 0.27726674638688564 device: cpu, dtype: torch.int32, 100000 times 0.2815811652690172 device: cpu, dtype: torch.int64, 100000 times 0.2852728571742773 device: cuda, dtype: torch.int8, 100000 times 1.4743850827217102 device: cuda, dtype: torch.uint8, 100000 times 1.4766502184793353 device: cuda, dtype: torch.int16, 100000 times 1.4774163831025362 device: cuda, dtype: torch.int32, 100000 times 1.4749693805351853 device: cuda, dtype: torch.int64, 100000 times 1.5772947426885366 __ior__ (a.numel() == 1000) for 10000 times device: cpu, dtype: torch.int8, 10000 times 0.03614502027630806 device: cpu, dtype: torch.uint8, 10000 times 0.03619729354977608 device: cpu, dtype: torch.int16, 10000 times 0.0319912089034915 device: cpu, dtype: torch.int32, 10000 times 0.03319283854216337 device: cpu, dtype: torch.int64, 10000 times 0.0343862259760499 device: cuda, dtype: torch.int8, 10000 times 0.1581476852297783 device: cuda, dtype: torch.uint8, 10000 times 0.15974601730704308 device: cuda, dtype: torch.int16, 10000 times 0.15957212820649147 device: cuda, dtype: torch.int32, 10000 times 0.16002820804715157 device: cuda, dtype: torch.int64, 10000 times 0.16129320487380028 ``` Fix https://github.com/pytorch/pytorch/issues/24511, https://github.com/pytorch/pytorch/issues/24515, https://github.com/pytorch/pytorch/issues/24658, https://github.com/pytorch/pytorch/issues/24662. Pull Request resolved: https://github.com/pytorch/pytorch/pull/31559 Differential Revision: D19315875 Pulled By: ezyang fbshipit-source-id: 4a3ca88fdafbeb796079687e676228111eb44aad
Author
Parents
Loading