03007b3d - Quantized Interpolate Kernel (upsample_bilinear2d) (#26631)

Summary:
This PR implements the quantized upsample_bilinear2d case of the interpolate kernel.

Benchmark used to measure the NHWC performance improvement:

```python
import torch, time

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_per_tensor(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])

    NITER = 100

    s = time.time()
    for i in range(NITER):
        float_out = torch.nn.functional.interpolate(
            x, size=5, scale_factor=None, mode="bilinear", align_corners=True)
    time_per_iter_float = (time.time() - s) / NITER

    s = time.time()
    for i in range(NITER):
        quant_out = torch.nn.quantized.functional.interpolate(
            q_x, size=5, scale_factor=None, mode="bilinear", align_corners=True)
    time_per_iter_quant = (time.time() - s) / NITER

    ref_quantized = torch.quantize_per_tensor(float_out, 0.5, 1, dtype)
    # torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000,
          time_per_iter_quant / time_per_iter_float, sep='\t')

    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()
    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9
    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')
```

Results without NHWC handling:

| dtype        | time/iter ms (float) | time/iter ms (quant) | quant/float | GB/s float | GB/s quant |
|--------------|----------------------|----------------------|-------------|------------|------------|
| torch.qint8  | 1.999                | 2.586                | 1.294       | 1.619      | 0.313      |
| torch.quint8 | 2.027                | 2.606                | 1.286       | 1.597      | 0.311      |
| torch.qint32 | 2.018                | 2.405                | 1.192       | 1.604      | 1.346      |

Results with NHWC handling:

| dtype        | time/iter ms (float) | time/iter ms (quant) | quant/float | GB/s float | GB/s quant |
|--------------|----------------------|----------------------|-------------|------------|------------|
| torch.qint8  | 2.091                | 0.097                | 0.046       | 1.548      | 8.345      |
| torch.quint8 | 2.107                | 0.100                | 0.047       | 1.537      | 8.125      |
| torch.qint32 | 2.044                | 0.600                | 0.294       | 1.583      | 5.392      |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/26631
Differential Revision: D17521498
Pulled By: llyfacebook
fbshipit-source-id: 385ae0f77777cd8bee385cafb80e492127b7d103
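For reference, here is a minimal usage sketch of the quantized path exercised above. The shapes, scale, and zero point mirror the benchmark; the channels_last check and the inline expected outputs are illustrative additions, not part of the original script:

```python
import torch

# Build an NHWC-shaped float tensor, quantize it, then permute to the
# logical NCHW dimension order. The permute leaves the quantized data
# channels-last in memory, which is the layout the NHWC fast path targets.
x = torch.rand(1, 56, 56, 256)                                # N, H, W, C
q_x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1,
                                dtype=torch.quint8)
q_x = q_x.permute([0, 3, 1, 2])                               # logical NCHW

print(q_x.is_contiguous(memory_format=torch.channels_last))   # True

# Quantized bilinear resize via the functional API this PR extends.
out = torch.nn.quantized.functional.interpolate(
    q_x, size=5, mode="bilinear", align_corners=True)
print(out.shape, out.dtype)   # torch.Size([1, 256, 5, 5]) torch.quint8
```

Comparing the two result tables, the NHWC-aware path moves the 8-bit quantized op from roughly 1.3x slower than the float kernel to more than 20x faster (the quant/float ratio drops from about 1.29 to 0.046).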