pytorch
5d65b5cd - Add the 3d upsample quantized op for video model (#34594)

Commit
4 years ago
Add the 3d upsample quantized op for video model (#34594) Summary: as title, we are currently missing this 3d op, which is required for video related model. Performance benchmark: ``` import torch, time for dtype in [torch.qint8, torch.quint8, torch.qint32]: print('****', str(dtype), '*****') x = torch.rand(1, 56, 64, 56, 256) q_x = torch.quantize_per_tensor(x, 0.5, 1, dtype) q_x = q_x.permute([0, 4, 1, 2, 3]) x = x.permute([0, 4, 1, 2, 3]) NITER = 100 s = time.time() for i in range(NITER): float_out = torch.nn.functional.interpolate(x, size=30, scale_factor=None, mode="nearest", align_corners=None) time_per_iter_float = (time.time() - s) / NITER s = time.time() for i in range(NITER): quant_out = torch.nn.functional.interpolate(q_x, size=30, scale_factor=None, mode="nearest", align_corners=None) time_per_iter_quant = (time.time() - s) / NITER ref_quantized = torch.quantize_per_tensor(float_out, 0.5, 1, dtype) torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize()) print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t') print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t') bytes_float = (x.numel() + float_out.numel()) * x.element_size() bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size() float_bw_gbps = bytes_float / time_per_iter_float / 1e9 quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9 print('GB/s float', 'GB/s quant', sep='\t') print(float_bw_gbps, quant_bw_gbps, sep='\t') ``` ``` **** torch.qint8 ***** time/iter ms (float) time/iter ms (quant) quant/float 1136.8209528923035 1.294245719909668 0.0011384780660638283 GB/s float GB/s quant 0.20510608588517917 45.03953391792442 **** torch.quint8 ***** time/iter ms (float) time/iter ms (quant) quant/float 827.9890131950378 1.11464262008667 0.0013462046021426 GB/s float GB/s quant 0.28160868355034036 52.29678369508914 **** torch.qint32 ***** time/iter ms (float) time/iter ms (quant) quant/float 834.6958303451538 7.481417655944824 0.008963046638020456 GB/s float GB/s quant 0.2793459455806586 31.16640544920269 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/34594 Differential Revision: D20389106 Pulled By: lly-zero-one fbshipit-source-id: d3a8c2cac58087d8b29e9cae64822f5b2d4c03ba
Author
Parents
Loading