Quantized Interpolate Kernel (upsample_bilinear2d) (#26631)
Summary:
This PR implements the quantized upsample_bilinear2d case for the interpolate kernel.
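For readers unfamiliar with the operation, here is a minimal pure-Python sketch of bilinear resizing with align_corners=True applied to integer (quantized) values. It is not the kernel from this PR. The helper names `bilinear_resize` and `quantized_bilinear_reference` are hypothetical, and the sketch assumes the output reuses the input's scale and zero point, so interpolating the stored integers and rounding is enough.

```python
def bilinear_resize(img, out_h, out_w):
    """Bilinear resize of a 2D list of numbers, align_corners=True semantics."""
    in_h, in_w = len(img), len(img[0])
    # With align_corners=True the corner pixels of input and output coincide.
    scale_h = (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    scale_w = (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
    out = []
    for oy in range(out_h):
        y = oy * scale_h
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        wy = y - y0  # vertical weight of the lower neighbour
        row = []
        for ox in range(out_w):
            x = ox * scale_w
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            wx = x - x0  # horizontal weight of the right neighbour
            top = img[y0][x0] * (1 - wx) + img[y0][x1] * wx
            bottom = img[y1][x0] * (1 - wx) + img[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out


def quantized_bilinear_reference(q_img, out_h, out_w, qmin, qmax):
    """Interpolate stored integer values, then round and clamp to [qmin, qmax].

    Assumption: output scale and zero point equal the input's. The bilinear
    weights sum to 1, so the zero point passes through unchanged.
    """
    resized = bilinear_resize(q_img, out_h, out_w)
    return [[min(qmax, max(qmin, round(v))) for v in row] for row in resized]


q_in = [[0, 10], [20, 30]]  # toy quint8 payload
q_out = quantized_bilinear_reference(q_in, 3, 3, qmin=0, qmax=255)
print(q_out)
```

Python's `round` uses round-half-to-even, which may differ from the rounding mode of a real kernel on exact ties.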
Benchmark used to measure the NHWC performance improvement:
import time
import torch

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    # Allocate in NHWC order, then permute to a logical NCHW view so the
    # underlying memory stays channels-last.
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_per_tensor(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])

    NITER = 100

    # Time the float kernel.
    s = time.time()
    for i in range(NITER):
        float_out = torch.nn.functional.interpolate(x, size=5, scale_factor=None, mode="bilinear", align_corners=True)
    time_per_iter_float = (time.time() - s) / NITER

    # Time the quantized kernel.
    s = time.time()
    for i in range(NITER):
        quant_out = torch.nn.quantized.functional.interpolate(q_x, size=5, scale_factor=None, mode="bilinear", align_corners=True)
    time_per_iter_quant = (time.time() - s) / NITER

    ref_quantized = torch.quantize_per_tensor(float_out, 0.5, 1, dtype)
    # torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t')

    # Effective bandwidth: input bytes read plus output bytes written per iteration.
    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()

    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9

    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')

===========without nhwc handling===========
**** torch.qint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
1.999044418334961       2.5860953330993652      1.2936657681940702
GB/s float              GB/s quant
1.6192056416115257      0.3129103516188541
**** torch.quint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
2.02730655670166        2.6061582565307617      1.2855274639721328
GB/s float              GB/s quant
1.596632728927902       0.3105014816242217
**** torch.qint32 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
2.0180463790893555      2.4047350883483887      1.1916153728010588
GB/s float              GB/s quant
1.603959172365819       1.3460376636426636

===========with nhwc handling===========
**** torch.qint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
2.0913314819335938      0.09696483612060547     0.04636512047863123
GB/s float              GB/s quant
1.5477527249803915      8.345458337015
**** torch.quint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
2.1065664291381836      0.09959936141967773     0.04728042754408879
GB/s float              GB/s quant
1.5365591871338384      8.124710725706763
**** torch.qint32 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
2.044203281402588       0.6003522872924805      0.29368521846837126
GB/s float              GB/s quant
1.5834354779917448      5.391607675216635

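To make the improvement easier to read off, the short script below derives the quantized-kernel speedup from the timings reported above and cross-checks one bandwidth figure. It only does arithmetic on numbers copied from this message. The variable and function names are illustrative, and the element sizes (1 byte for qint8/quint8, 4 bytes for qint32) and the 1x256x5x5 output shape are assumptions inferred from the benchmark script.

```python
# Quantized time per iteration in ms, copied from the results above.
quant_ms_without_nhwc = {
    "qint8": 2.5860953330993652,
    "quint8": 2.6061582565307617,
    "qint32": 2.4047350883483887,
}
quant_ms_with_nhwc = {
    "qint8": 0.09696483612060547,
    "quint8": 0.09959936141967773,
    "qint32": 0.6003522872924805,
}

# Speedup of the quantized kernel due to NHWC handling.
speedup = {
    name: quant_ms_without_nhwc[name] / quant_ms_with_nhwc[name]
    for name in quant_ms_without_nhwc
}
for name, factor in speedup.items():
    print(f"{name}: {factor:.1f}x faster with nhwc handling")

# Cross-check of the bandwidth formula used by the benchmark:
# (input elements + output elements) * element size / time per iteration.
IN_ELEMS = 1 * 56 * 56 * 256   # input tensor
OUT_ELEMS = 1 * 256 * 5 * 5    # output tensor for size=5 (assumed shape)


def bandwidth_gbps(element_size_bytes, ms_per_iter):
    total_bytes = (IN_ELEMS + OUT_ELEMS) * element_size_bytes
    return total_bytes / (ms_per_iter / 1000.0) / 1e9


print(bandwidth_gbps(1, quant_ms_with_nhwc["qint8"]))   # compare with 8.345...
print(bandwidth_gbps(4, quant_ms_with_nhwc["qint32"]))  # compare with 5.391...
```

By this arithmetic the 8-bit types speed up by roughly 26x and qint32 by roughly 4x, while the float timings stay at about 2 ms in both runs.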
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26631
Differential Revision: D17521498
Pulled By: llyfacebook
fbshipit-source-id: 385ae0f77777cd8bee385cafb80e492127b7d103