Factor unnecessary work out of add inner loop (#25751)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25751
This PR does two things:
1) Factor the redundant scale inversion (1 / scale) out of the quantize function called in the inner loop. This saves cycles in the inner kernel (unfortunately, the compiler did not hoist the computation out of the loop on its own).
2) Use FMA in the dequantize routine when possible. This requires the caller to pass in a pre-multiplied (scale * -zero_point) vector.
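As a rough scalar illustration of both changes (not the vectorized C++ kernel itself), the sketch below compares a quantize loop that recomputes 1 / scale per element with one that hoists it, and a subtract-then-multiply dequantize with the multiply-add form that takes the pre-multiplied term. All function names, the qmin/qmax defaults, and the sample values are hypothetical and chosen only for this example.

```python
import math


def quantize_naive(values, scale, zero_point, qmin=0, qmax=255):
    """Recomputes 1 / scale for every element (the redundant work)."""
    out = []
    for v in values:
        inv_scale = 1.0 / scale  # loop-invariant, yet recomputed each iteration
        q = round(v * inv_scale) + zero_point
        out.append(min(max(q, qmin), qmax))
    return out


def quantize_hoisted(values, scale, zero_point, qmin=0, qmax=255):
    """Computes 1 / scale once, before the loop."""
    inv_scale = 1.0 / scale
    return [min(max(round(v * inv_scale) + zero_point, qmin), qmax) for v in values]


def dequantize_sub_mul(qvalues, scale, zero_point):
    """Reference form: (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in qvalues]


def dequantize_mul_add(qvalues, scale, scale_neg_zp_premul):
    """Multiply-add form: q * scale + (scale * -zero_point).

    Plain Python performs a separate multiply and add here; the point is
    only that this shape maps onto a single FMA instruction in a SIMD kernel.
    """
    return [q * scale + scale_neg_zp_premul for q in qvalues]


# Hypothetical sample inputs, using the same scale and zero point as the benchmark.
scale, zero_point = 0.1, 5
vals = [0.0, 0.12, 0.5, 1.0, -0.3]

q_naive = quantize_naive(vals, scale, zero_point)
q_hoisted = quantize_hoisted(vals, scale, zero_point)

premul = scale * -zero_point  # computed once by the caller
deq_ref = dequantize_sub_mul(q_hoisted, scale, zero_point)
deq_fma = dequantize_mul_add(q_hoisted, scale, premul)

print(q_naive == q_hoisted)
print(all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(deq_ref, deq_fma)))
```

The two dequantize forms are algebraically identical, but can differ in the last bits of floating-point rounding, which is why the comparison uses a tolerance rather than exact equality.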
Benchmark Script
```
import torch
import time

x = torch.rand(1, 256, 56, 56)
y = torch.rand(1, 256, 56, 56)

print('dtype', 'ms/iter (float)', 'ms/iter (quant)', 'quant / float', sep='\t')

for dtype in [torch.quint8, torch.qint8, torch.qint32]:
    qX = torch.quantize_linear(x, 0.1, 5, dtype).permute([0, 3, 1, 2])
    qY = torch.quantize_linear(y, 0.1, 5, dtype).permute([0, 3, 1, 2])

    _x = x.permute([0, 3, 1, 2])
    _y = y.permute([0, 3, 1, 2])

    NITER = 10000

    # Test float
    s = time.time()
    for i in range(NITER):
        _x + _y
    elapsed_float = time.time() - s
    ms_per_iter_float = elapsed_float / NITER * 1000

    # Test quantized
    s = time.time()
    for i in range(NITER):
        torch.ops.quantized.add(qX, qY, 0.1, 5)
    elapsed = time.time() - s
    ms_per_iter = elapsed / NITER * 1000

    print(str(dtype), ms_per_iter_float, ms_per_iter, ms_per_iter / ms_per_iter_float, sep='\t')
    print('float gbps', 'quant gbps', sep='\t')
    print((x.numel() + 2 * y.numel()) * x.element_size() / ms_per_iter_float / 1e6,
          (qX.numel() + 2 * qX.numel()) * qX.element_size() / ms_per_iter / 1e6,
          sep='\t')
```
Before this change
```
dtype ms/iter (float) ms/iter (quant) quant / float
torch.quint8 0.47297704219818115 0.1909616231918335 0.403743958278252
float gbps quant gbps
20.368413560257675 12.612209509659206
torch.qint8 0.4638909578323364 0.18829500675201416 0.40590359344764254
float gbps quant gbps
20.767363185988053 12.79082245219568
torch.qint32 0.4605833768844605 4.219791603088379 9.161840862847583
float gbps quant gbps
20.916499560114787 2.2830018413585225
```
After this change
```
dtype ms/iter (float) ms/iter (quant) quant / float
torch.quint8 0.465389084815979 0.1516613483428955 0.3258807593282176
float gbps quant gbps
20.70051128038237 15.880433784319726
torch.qint8 0.4630591154098511 0.15664465427398683 0.3382821956443757
float gbps quant gbps
20.804669812996085 15.375232631861083
torch.qint32 0.4726278781890869 4.103795266151429 8.682931023610927
float gbps quant gbps
20.38346116380751 2.347532314650444
```
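To compare the two tables directly, the quantized-add speedup per dtype is the ratio of the before and after ms/iter (quant) columns. A small sketch, with the timings copied from the output above; the variable names are illustrative only:

```python
# (before, after) ms/iter for quantized add, copied from the benchmark output
timings_ms = {
    "torch.quint8": (0.1909616231918335, 0.1516613483428955),
    "torch.qint8": (0.18829500675201416, 0.15664465427398683),
    "torch.qint32": (4.219791603088379, 4.103795266151429),
}

# Speedup > 1 means the quantized add got faster after this change.
speedups = {name: before / after for name, (before, after) in timings_ms.items()}

for name, speedup in speedups.items():
    print(f"{name}\t{speedup:.2f}x")
```

These are single-run numbers from one machine, so small differences (such as the qint32 row) may be within measurement noise.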
Test Plan: Imported from OSS
Differential Revision: D17222302
Pulled By: jamesr66a
fbshipit-source-id: fffc819f565dfd3b85fb6496c7c6635ec2c237a4