Quantized Interpolate Kernel(upsample_nearest2d) (#26617)
Summary:
In this PR, we add support for quantized interpolate in the upsample_nearest2d case. The script below benchmarks the new quantized path against the float path:
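Before the benchmark, a minimal usage sketch of the op this PR enables (names follow the public PyTorch API; the specific scale/zero_point values are arbitrary examples):

```python
import torch

# Nearest-neighbor upsampling of a per-tensor-quantized tensor via the
# quantized functional op added by this PR.
x = torch.rand(1, 3, 4, 4)
q_x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8)
q_up = torch.nn.quantized.functional.interpolate(q_x, scale_factor=2, mode="nearest")
# q_up is still quantized, carrying the input's quantization parameters.
```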
import torch, time

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_per_tensor(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])

    NITER = 100

    s = time.time()
    for i in range(NITER):
        # float_out = torch.nn.functional.avg_pool2d(x, kernel_size=5, stride=None, padding=0)
        # float_out = torch.nn.functional.adaptive_avg_pool2d(x, output_size=5)
        float_out = torch.nn.functional.interpolate(x, size=5, scale_factor=None, mode="nearest", align_corners=None)
    time_per_iter_float = (time.time() - s) / NITER

    s = time.time()
    for i in range(NITER):
        # quant_out = torch.nn.quantized.functional.avg_pool2d(q_x, kernel_size=5, stride=None, padding=0)
        # quant_out = torch.nn.quantized.functional.adaptive_avg_pool2d(q_x, output_size=5)
        quant_out = torch.nn.quantized.functional.interpolate(q_x, size=5, scale_factor=None, mode="nearest", align_corners=None)
    time_per_iter_quant = (time.time() - s) / NITER

    ref_quantized = torch.quantize_per_tensor(float_out, 0.5, 1, dtype)
    # torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t')

    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()
    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9
    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')
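Note that in the benchmark the tensor is allocated as NHWC and then permuted to NCHW, which exercises the channels-last fast path: permute only rewrites strides, so the data stays NHWC in memory. A small check of that stride behavior (standard PyTorch semantics, not specific to this PR):

```python
import torch

# A tensor allocated contiguously as N, H, W, C and permuted to N, C, H, W
# keeps its NHWC strides: permute swaps metadata only, with no copy.
x = torch.rand(1, 56, 56, 256)      # contiguous NHWC allocation
x_nchw = x.permute([0, 3, 1, 2])    # logical NCHW view over the same storage
print(x_nchw.is_contiguous(memory_format=torch.channels_last))  # True
print(x_nchw.data_ptr() == x.data_ptr())                        # True: no copy
```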
========= Without special handling of NHWC layout =========
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.08712100982666 2.1624231338500977 1.0360794240817361
GB/s float GB/s quant
1.5508750976872339 0.37421723220248165
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.056601047515869 2.184889316558838 1.0623787823107091
GB/s float GB/s quant
1.573890086222483 0.3703693335250963
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.0152783393859863 2.067704200744629 1.0260142037623525
GB/s float GB/s quant
1.6061622539873104 1.5654386148823074
========= With special handling of NHWC layout =========
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.044649124145508 0.009250640869140625 0.004524317038018256
GB/s float GB/s quant
1.5830902044636819 87.47675014597938
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.049403190612793 0.009107589721679688 0.004444020465761265
GB/s float GB/s quant
1.579417859221808 88.8507305147644
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.0601415634155273 0.01062631607055664 0.0051580513976618066
GB/s float GB/s quant
1.5711852318699757 304.6082930818039
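The large speedup in the NHWC-handled kernel comes from the memory layout: with channels last, each output pixel maps to one contiguous C-element vector of the input, so the inner loop becomes a wide contiguous copy instead of a strided per-channel gather. A hypothetical sketch of that access pattern (the kernel in this PR is C++; the function name below is made up for illustration):

```python
import torch

def nearest_upsample_nhwc(x, out_h, out_w):
    # x is logically N, H, W, C and contiguous in that order.
    n, h, w, c = x.shape
    out = x.new_empty((n, out_h, out_w, c))
    for oh in range(out_h):
        ih = oh * h // out_h                    # nearest source row
        for ow in range(out_w):
            iw = ow * w // out_w                # nearest source column
            # One contiguous C-wide copy per output pixel (memcpy-like),
            # rather than C strided scalar loads as in an NCHW loop nest.
            out[:, oh, ow, :] = x[:, ih, iw, :]
    return out
```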
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26617
Differential Revision: D17519146
Pulled By: llyfacebook
fbshipit-source-id: 126876e550ef7009fd75f5ccc033599f1f37456d