Vectorized quantized relu/relu6 (#25496)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25496
Benchmark Script
```
# Micro-benchmark: float vs. quantized relu / relu6 on ResNet-style
# activation shapes, one timing table per quantized dtype.
import torch, time

# Activation shapes from typical ResNet bottleneck stages, channels last.
sizes = [
    (1, 56, 56, 256),
    (1, 28, 28, 512),
    (1, 14, 14, 1024),
    (1, 7, 7, 2048),
]
NITER = 1000  # iterations per timing measurement

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('*****', str(dtype), '*****')

    print('\t*****relu*****')
    print('\tsize',
          'time (float ms)',
          'time (quant ms)',
          'quant / float',
          sep='\t')
    for size in sizes:
        # NHWC
        x = torch.rand(*size)
        # NCHW
        # NOTE(review): [0, 2, 3, 1] is the NCHW->NHWC permutation, so the
        # NHWC/NCHW labels above appear swapped -- confirm intended layout.
        x = x.permute([0, 2, 3, 1])

        # Time float relu.
        s = time.time()
        for i in range(NITER):
            torch.relu(x)
        time_per_iter_float = (time.time() - s) / NITER

        # Time quantized relu on the same data.
        q_x = torch.quantize_linear(x, 0.5, 1, dtype)
        s = time.time()
        for i in range(NITER):
            torch.relu(q_x)
        time_per_iter_quant = (time.time() - s) / NITER

        print('\t',
              size,
              time_per_iter_float * 1000,
              time_per_iter_quant * 1000,
              time_per_iter_quant / time_per_iter_float,
              sep='\t')

    print('\t*****relu6*****')
    print('\tsize',
          'time (float ms)',
          'time (quant ms)',
          'quant / float',
          sep='\t')
    for size in sizes:
        # NHWC
        x = torch.rand(*size)
        # NCHW
        x = x.permute([0, 2, 3, 1])

        # Time float relu6 (hardtanh clamped to [0, 6]).
        s = time.time()
        for i in range(NITER):
            torch._C._nn.hardtanh(x, 0., 6.)
        time_per_iter_float_6 = (time.time() - s) / NITER

        # Time quantized relu6.
        q_x = torch.quantize_linear(x, 0.5, 1, dtype)
        s = time.time()
        for i in range(NITER):
            torch.ops.quantized.relu6(q_x)
        time_per_iter_quant_6 = (time.time() - s) / NITER

        print('\t',
              size,
              time_per_iter_float_6 * 1000,
              time_per_iter_quant_6 * 1000,
              time_per_iter_quant_6 / time_per_iter_float_6,
              sep='\t')
```
Before this change (AVX2)
```
$ OMP_NUM_THREADS=1 python relu_bench.py
***** torch.qint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.28845906257629395 0.32473158836364746 1.1257458353479874
(1, 28, 28, 512) 0.12658190727233887 0.1621997356414795 1.2813816692816096
(1, 14, 14, 1024) 0.060466766357421875 0.08151435852050781 1.3480852943031985
(1, 7, 7, 2048) 0.021933555603027344 0.04172706604003906 1.9024305404582809
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.0264298915863037 0.4686436653137207 0.45657640054641424
(1, 28, 28, 512) 0.4577608108520508 0.23253798484802246 0.5079901541051298
(1, 14, 14, 1024) 0.22967290878295898 0.11695981025695801 0.509245129853278
(1, 7, 7, 2048) 0.12731575965881348 0.060141801834106445 0.4723830105187069
***** torch.quint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.28515172004699707 0.32268643379211426 1.1316306762551913
(1, 28, 28, 512) 0.1268613338470459 0.1618938446044922 1.2761480562681475
(1, 14, 14, 1024) 0.06022787094116211 0.08164644241333008 1.355625578946535
(1, 7, 7, 2048) 0.018331527709960938 0.04460000991821289 2.432967433149516
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.027123212814331 0.5206699371337891 0.50692062124382
(1, 28, 28, 512) 0.4589383602142334 0.25958728790283203 0.565625605542444
(1, 14, 14, 1024) 0.23261427879333496 0.13058066368103027 0.561361341867771
(1, 7, 7, 2048) 0.13072657585144043 0.06684517860412598 0.5113358027528374
***** torch.qint32 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.285900354385376 0.44794583320617676 1.5667900593168678
(1, 28, 28, 512) 0.12691712379455566 0.21081137657165527 1.6610160258035915
(1, 14, 14, 1024) 0.05957603454589844 0.10731720924377441 1.8013486473507283
(1, 7, 7, 2048) 0.01675701141357422 0.05678510665893555 3.388737123669683
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.0314903259277344 0.6447939872741699 0.6251090980366052
(1, 28, 28, 512) 0.4572310447692871 0.3106963634490967 0.6795172090859886
(1, 14, 14, 1024) 0.2294166088104248 0.1586904525756836 0.6917130080447454
(1, 7, 7, 2048) 0.12760710716247559 0.07992196083068848 0.6263127705647926
```
After this change (AVX2)
```
$ OMP_NUM_THREADS=1 python relu_bench.py
***** torch.qint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.2889232635498047 0.06460881233215332 0.22361928056034167
(1, 28, 28, 512) 0.13853216171264648 0.013955354690551758 0.10073729102343015
(1, 14, 14, 1024) 0.0721442699432373 0.007253408432006836 0.10054032617855548
(1, 7, 7, 2048) 0.015225648880004883 0.004289150238037109 0.28170557930505313
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.042311191558838 0.06422209739685059 0.061615089540392104
(1, 28, 28, 512) 0.46384429931640625 0.01335287094116211 0.028787399049295198
(1, 14, 14, 1024) 0.2301616668701172 0.007760286331176758 0.033716675920477994
(1, 7, 7, 2048) 0.12573981285095215 0.004631757736206055 0.03683604763827976
***** torch.quint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.2877991199493408 0.0571134090423584 0.1984488661828141
(1, 28, 28, 512) 0.12664175033569336 0.013076543807983398 0.10325618347283565
(1, 14, 14, 1024) 0.06389951705932617 0.005294084548950195 0.08285014961904974
(1, 7, 7, 2048) 0.016280174255371094 0.003660917282104492 0.22486966199988284
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.0244698524475098 0.05978655815124512 0.05835853344870231
(1, 28, 28, 512) 0.454937219619751 0.013289213180541992 0.02921109244842504
(1, 14, 14, 1024) 0.22972846031188965 0.0077877044677734375 0.03389960676705229
(1, 7, 7, 2048) 0.125657320022583 0.0045795440673828125 0.03644470586003093
***** torch.qint32 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.28399205207824707 0.2665698528289795 0.9386525111468004
(1, 28, 28, 512) 0.12665152549743652 0.12166023254394531 0.9605903447756557
(1, 14, 14, 1024) 0.0598299503326416 0.059305429458618164 0.9912331387355795
(1, 7, 7, 2048) 0.014290809631347656 0.012906551361083984 0.9031364698031366
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.020923376083374 0.27229976654052734 0.2667191024513184
(1, 28, 28, 512) 0.4564201831817627 0.12390279769897462 0.2714665176181136
(1, 14, 14, 1024) 0.23244047164916992 0.05935955047607422 0.25537527976482316
(1, 7, 7, 2048) 0.1271505355834961 0.014976024627685547 0.11778184463762029
```
Test Plan: Imported from OSS
Differential Revision: D17141891
Pulled By: jamesr66a
fbshipit-source-id: 14b8c3330017c518a6b385780a449ca51efef0ce