f7bcba33 - Vectorized specialization of max_pool2d for channels-last layout (#25676)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25676

This PR achieves two things:

1) Ensures the channels-last layout is propagated through the operator if we receive an input in that layout. This helps to alleviate unnecessary data movement in, e.g., ResNet inference.
2) Applies interleaved vectorization along the channel dimension in the kernel. This allows us to use the functional units on the CPU much more effectively.

Benchmark script:

```
import torch, time

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_linear(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])

    NITER = 100

    s = time.time()
    for i in range(NITER):
        float_out = torch.max_pool2d(x, kernel_size=3, stride=None, padding=0, dilation=1)
    time_per_iter_float = (time.time() - s) / NITER

    s = time.time()
    for i in range(NITER):
        quant_out = torch.max_pool2d(q_x, kernel_size=3, stride=None, padding=0, dilation=1)
    time_per_iter_quant = (time.time() - s) / NITER

    ref_quantized = torch.quantize_linear(float_out, 0.5, 1, dtype)
    torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000,
          time_per_iter_quant / time_per_iter_float, sep='\t')

    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()
    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9
    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')
```

Before this change (DynDisp to AVX2):

```
**** torch.qint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
5.197856426239014       1.2381434440612793      0.23820270175433766
GB/s float              GB/s quant
0.6816348335661166      0.7153936841878243
**** torch.quint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
5.14232873916626        1.1790156364440918      0.2292765974808621
GB/s float              GB/s quant
0.6889952353715999      0.7512707826941549
**** torch.qint32 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
4.918942451477051       3.401169776916504       0.6914432950715265
GB/s float              GB/s quant
0.7202849057394649      1.041712185038912
```

After this change (DynDisp to AVX2):

```
**** torch.qint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
5.0574493408203125      0.018107891082763672    0.0035804394394243393
GB/s float              GB/s quant
0.700558673203699       48.915690731270566
**** torch.quint8 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
4.984829425811768       0.016908645629882812    0.0033920209069399163
GB/s float              GB/s quant
0.7107645412406512      52.38503540665539
**** torch.qint32 *****
time/iter ms (float)    time/iter ms (quant)    quant/float
4.973354339599609       0.13938188552856445     0.028025729922108406
GB/s float              GB/s quant
0.7124044976624851      25.419658993448625
```

Test Plan: Imported from OSS

Differential Revision: D17196457

Pulled By: jamesr66a

fbshipit-source-id: 614be60ed74bed5d0369c58cc450b430cfabe5fb
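As a rough illustration of the two points above, here is a minimal sketch (not part of the PR) that checks layout propagation on a recent PyTorch build; it assumes torch.quantize_per_tensor, the later name of the quantize_linear call used in the benchmark.

```
import torch

# Quantize an NHWC tensor, then view it as logical NCHW: the quantized
# tensor now has channels-last strides, as in the benchmark above.
x = torch.rand(1, 56, 56, 256)
q_x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8)
q_x = q_x.permute([0, 3, 1, 2])

out = torch.max_pool2d(q_x, kernel_size=3)

# (1) Layout propagation: with this change the pooled output keeps
# channels-last strides, so no NHWC <-> NCHW repacking is needed
# around the op. Expected: True.
print(out.is_contiguous(memory_format=torch.channels_last))

# (2) Why channels-last enables the vectorization: for a fixed
# (n, h, w) position, all 256 channel values are adjacent in memory,
# so the kernel can reduce a 3x3 pooling window with elementwise maxes
# over contiguous 256-wide rows -- one SIMD lane per channel
# (e.g. 32 int8 lanes per AVX2 vector).
```

On the quint8 numbers above, the kernel change works out to roughly a 70x speedup (about 1.18 ms down to about 0.017 ms per iteration).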
Author: James Reed