Add op bitwise_and (#31104)
Summary:
Following https://github.com/pytorch/pytorch/pull/25665, this PR adds the `bitwise_and` operator.
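For reference, `bitwise_and` computes an elementwise AND over integral (and bool) tensors. A minimal plain-Python sketch of the intended semantics (torch itself is not imported here; `bitwise_and(a, b)[i] == a[i] & b[i]`):

```
# Illustration of the elementwise semantics that torch.bitwise_and,
# Tensor.bitwise_and_, and the `&` operator provide for integer inputs.
a = [3, 5, 255]
b = [1, 4, 15]

# Elementwise AND, the same result torch.bitwise_and(a, b) would produce
# for int tensors of these values.
out = [x & y for x, y in zip(a, b)]
print(out)  # [1, 4, 15]
```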
Benchmark script:
```
import timeit

# for __and__
for n, t in [(10, 100000), (1000, 10000)]:
    print('__and__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a & b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")', number=t))

# for __iand__
for n, t in [(10, 100000), (1000, 10000)]:
    print('__iand__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a &= b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t))
```
Device: **Tesla P100, skx-8180**
CUDA version: **9.0.176**
Before:
```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.1766007635742426
device: cpu, dtype: torch.uint8, 100000 times 0.17322628945112228
device: cpu, dtype: torch.int16, 100000 times 0.17650844901800156
device: cpu, dtype: torch.int32, 100000 times 0.17711848113685846
device: cpu, dtype: torch.int64, 100000 times 0.18240160401910543
device: cuda, dtype: torch.int8, 100000 times 1.273967768996954
device: cuda, dtype: torch.uint8, 100000 times 1.2778537990525365
device: cuda, dtype: torch.int16, 100000 times 1.2753686187788844
device: cuda, dtype: torch.int32, 100000 times 1.2797665279358625
device: cuda, dtype: torch.int64, 100000 times 1.2933144550770521
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.031139614060521126
device: cpu, dtype: torch.uint8, 10000 times 0.03091452084481716
device: cpu, dtype: torch.int16, 10000 times 0.022756479680538177
device: cpu, dtype: torch.int32, 10000 times 0.025045674294233322
device: cpu, dtype: torch.int64, 10000 times 0.024164282716810703
device: cuda, dtype: torch.int8, 10000 times 0.12820732593536377
device: cuda, dtype: torch.uint8, 10000 times 0.12775669433176517
device: cuda, dtype: torch.int16, 10000 times 0.12697868794202805
device: cuda, dtype: torch.int32, 10000 times 0.12832533661276102
device: cuda, dtype: torch.int64, 10000 times 0.1280576130375266
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.3687064303085208
device: cpu, dtype: torch.uint8, 100000 times 0.36253443732857704
device: cpu, dtype: torch.int16, 100000 times 0.362891579978168
device: cpu, dtype: torch.int32, 100000 times 0.37680106051266193
device: cpu, dtype: torch.int64, 100000 times 0.3689364707097411
device: cuda, dtype: torch.int8, 100000 times 1.419940729625523
device: cuda, dtype: torch.uint8, 100000 times 1.4247053815051913
device: cuda, dtype: torch.int16, 100000 times 1.4191444097086787
device: cuda, dtype: torch.int32, 100000 times 1.4305962566286325
device: cuda, dtype: torch.int64, 100000 times 1.4567416654899716
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.06224383972585201
device: cpu, dtype: torch.uint8, 10000 times 0.06205617543309927
device: cpu, dtype: torch.int16, 10000 times 0.05016433447599411
device: cpu, dtype: torch.int32, 10000 times 0.05216377507895231
device: cpu, dtype: torch.int64, 10000 times 0.06139362137764692
device: cuda, dtype: torch.int8, 10000 times 0.14827249851077795
device: cuda, dtype: torch.uint8, 10000 times 0.14801877550780773
device: cuda, dtype: torch.int16, 10000 times 0.14952312968671322
device: cuda, dtype: torch.int32, 10000 times 0.14999118447303772
device: cuda, dtype: torch.int64, 10000 times 0.14951884001493454
```
After:
```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.23157884553074837
device: cpu, dtype: torch.uint8, 100000 times 0.23063660878688097
device: cpu, dtype: torch.int16, 100000 times 0.23005440644919872
device: cpu, dtype: torch.int32, 100000 times 0.23748818412423134
device: cpu, dtype: torch.int64, 100000 times 0.24106105230748653
device: cuda, dtype: torch.int8, 100000 times 1.4394256137311459
device: cuda, dtype: torch.uint8, 100000 times 1.4436759827658534
device: cuda, dtype: torch.int16, 100000 times 1.4631587155163288
device: cuda, dtype: torch.int32, 100000 times 1.459101552143693
device: cuda, dtype: torch.int64, 100000 times 1.4784048134461045
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.028442862443625927
device: cpu, dtype: torch.uint8, 10000 times 0.028130197897553444
device: cpu, dtype: torch.int16, 10000 times 0.025318274274468422
device: cpu, dtype: torch.int32, 10000 times 0.02519288007169962
device: cpu, dtype: torch.int64, 10000 times 0.028299466706812382
device: cuda, dtype: torch.int8, 10000 times 0.14342594426125288
device: cuda, dtype: torch.uint8, 10000 times 0.145280827768147
device: cuda, dtype: torch.int16, 10000 times 0.14673697855323553
device: cuda, dtype: torch.int32, 10000 times 0.14499565307050943
device: cuda, dtype: torch.int64, 10000 times 0.14582364354282618
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times 0.25548241566866636
device: cpu, dtype: torch.uint8, 100000 times 0.2552562616765499
device: cpu, dtype: torch.int16, 100000 times 0.25905191246420145
device: cpu, dtype: torch.int32, 100000 times 0.26635489892214537
device: cpu, dtype: torch.int64, 100000 times 0.26269810926169157
device: cuda, dtype: torch.int8, 100000 times 1.485458506271243
device: cuda, dtype: torch.uint8, 100000 times 1.4742380809038877
device: cuda, dtype: torch.int16, 100000 times 1.507783885113895
device: cuda, dtype: torch.int32, 100000 times 1.4926990242674947
device: cuda, dtype: torch.int64, 100000 times 1.519851053133607
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times 0.03425929415971041
device: cpu, dtype: torch.uint8, 10000 times 0.03293587639927864
device: cpu, dtype: torch.int16, 10000 times 0.029559112153947353
device: cpu, dtype: torch.int32, 10000 times 0.030915481969714165
device: cpu, dtype: torch.int64, 10000 times 0.03292469773441553
device: cuda, dtype: torch.int8, 10000 times 0.15792148280888796
device: cuda, dtype: torch.uint8, 10000 times 0.16000914946198463
device: cuda, dtype: torch.int16, 10000 times 0.1600684942677617
device: cuda, dtype: torch.int32, 10000 times 0.16162546630948782
device: cuda, dtype: torch.int64, 10000 times 0.1629159888252616
```
Fixes https://github.com/pytorch/pytorch/issues/24508, https://github.com/pytorch/pytorch/issues/24509, https://github.com/pytorch/pytorch/issues/24655, https://github.com/pytorch/pytorch/issues/24656.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31104
Differential Revision: D18938930
Pulled By: VitalyFedyunin
fbshipit-source-id: a77e805a0b84e8ace16c6e648c2f67dad44f2e44