pytorch
bc6eec1d - Factor unnecessary work out of add inner loop (#25751)

Commit
6 years ago
Factor unnecessary work out of add inner loop (#25751)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25751

This PR does several things:

1) Factor unnecessary scale inversion out of the quantize function in the inner loop. This saves cycles in the inner kernel (unfortunately the compiler couldn't hoist it out automatically for some reason).
2) Use FMA in the dequantize routine when possible. This also necessitates having the user pass in a pre-multiplied (scale * -zero_point) vector.

Benchmark Script

```python
import torch
import time

x = torch.rand(1, 256, 56, 56)
y = torch.rand(1, 256, 56, 56)

print('dtype', 'ms/iter (float)', 'ms/iter (quant)', 'quant / float', sep='\t')

for dtype in [torch.quint8, torch.qint8, torch.qint32]:
    qX = torch.quantize_linear(x, 0.1, 5, dtype).permute([0, 3, 1, 2])
    qY = torch.quantize_linear(y, 0.1, 5, dtype).permute([0, 3, 1, 2])
    _x = x.permute([0, 3, 1, 2])
    _y = y.permute([0, 3, 1, 2])

    NITER = 10000

    # Test float
    s = time.time()
    for i in range(NITER):
        _x + _y
    elapsed_float = time.time() - s
    ms_per_iter_float = elapsed_float / NITER * 1000

    # Test quantized
    s = time.time()
    for i in range(NITER):
        torch.ops.quantized.add(qX, qY, 0.1, 5)
    elapsed = time.time() - s
    ms_per_iter = elapsed / NITER * 1000

    print(str(dtype), ms_per_iter_float, ms_per_iter, ms_per_iter / ms_per_iter_float, sep='\t')
    print('float gbps', 'quant gbps', sep='\t')
    print((x.numel() + 2 * y.numel()) * x.element_size() / ms_per_iter_float / 1e6,
          (qX.numel() + 2 * qX.numel()) * qX.element_size() / ms_per_iter / 1e6, sep='\t')
```

Before this change

```
dtype           ms/iter (float)      ms/iter (quant)      quant / float
torch.quint8    0.47297704219818115  0.1909616231918335   0.403743958278252
float gbps      quant gbps
20.368413560257675    12.612209509659206
torch.qint8     0.4638909578323364   0.18829500675201416  0.40590359344764254
float gbps      quant gbps
20.767363185988053    12.79082245219568
torch.qint32    0.4605833768844605   4.219791603088379    9.161840862847583
float gbps      quant gbps
20.916499560114787    2.2830018413585225
```

After this change

```
dtype           ms/iter (float)      ms/iter (quant)      quant / float
torch.quint8    0.465389084815979    0.1516613483428955   0.3258807593282176
float gbps      quant gbps
20.70051128038237     15.880433784319726
torch.qint8     0.4630591154098511   0.15664465427398683  0.3382821956443757
float gbps      quant gbps
20.804669812996085    15.375232631861083
torch.qint32    0.4726278781890869   4.103795266151429    8.682931023610927
float gbps      quant gbps
20.38346116380751     2.347532314650444
```

Test Plan: Imported from OSS

Differential Revision: D17222302

Pulled By: jamesr66a

fbshipit-source-id: fffc819f565dfd3b85fb6496c7c6635ec2c237a4
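A minimal scalar sketch in Python of the two transformations described in the commit message above. This is not the actual PyTorch kernel (which is vectorized C++); the function names, the plain-Python loop, the single shared scale/zero point, and the omission of saturation to the output dtype range are illustrative assumptions only.

```python
def quantized_add_naive(qx, qy, scale, zero_point, out_scale, out_zero_point):
    """Per-element work before the change: the scale inversion (1 / out_scale)
    and the zero-point subtraction both happen inside the inner loop.
    Saturation/clamping to the output dtype range is omitted for brevity."""
    out = []
    for a, b in zip(qx, qy):
        fa = scale * (a - zero_point)             # dequantize: subtract, then multiply
        fb = scale * (b - zero_point)
        q = round((fa + fb) * (1.0 / out_scale))  # quantize: 1/out_scale recomputed
        out.append(q + out_zero_point)            #   on every iteration
    return out


def quantized_add_optimized(qx, qy, scale, zero_point, out_scale, out_zero_point):
    """Per-element work after the change: 1 / out_scale is hoisted out of the
    loop, and dequantization is rewritten as scale * q + (scale * -zero_point),
    with the second term precomputed once (the pre-multiplied vector the commit
    message mentions). In a SIMD kernel that shape maps to a single FMA."""
    inv_out_scale = 1.0 / out_scale    # change 1: scale inversion hoisted out of the loop
    premul = scale * -zero_point       # change 2: precomputed scale * -zero_point
    out = []
    for a, b in zip(qx, qy):
        fa = scale * a + premul        # fused multiply-add form of dequantize
        fb = scale * b + premul
        out.append(round((fa + fb) * inv_out_scale) + out_zero_point)
    return out


if __name__ == "__main__":
    qx, qy = [7, 5, 130], [5, 9, 12]
    print(quantized_add_naive(qx, qy, 0.1, 5, 0.1, 5))
    print(quantized_add_optimized(qx, qy, 0.1, 5, 0.1, 5))
```

Both versions compute the same values; the point of the rewrite is that the division and the zero-point handling no longer add per-element work, leaving only a multiply-add and a rounding step in the hot loop.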
Author
James Reed