Vectorized specialization of max_pool2d for channels-last layout (#25676)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25676
This PR achieves two things:
1) Ensures the channels-last layout is propagated through the operator if we receive an input in that layout. This helps avoid unnecessary data movement in, e.g., ResNet inference (see the sketch after this list).
2) Applies interleaved vectorization along the channel dimension in the kernel. This lets us use the CPU's functional units much more effectively.
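A minimal sketch of what these two points mean in practice (my illustration, not code from the PR; it assumes a PyTorch build with `torch.channels_last` memory-format support and uses `torch.quantize_per_tensor`, the newer name for the `torch.quantize_linear` call in the benchmark below):
```
# Illustration only (not from the PR). Build a quantized channels-last tensor
# the same way the benchmark below does: allocate with NHWC dimensions, then
# permute to NCHW dimension order so that the channel stride is 1.
import torch

x = torch.rand(1, 56, 56, 256)
q_x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8)
q_x = q_x.permute([0, 3, 1, 2])  # NCHW dims, NHWC (channels-last) strides

# Point 1: with layout propagation, the pooled output keeps the channels-last
# layout, so no NCHW <-> NHWC copies are inserted around the op.
out = torch.max_pool2d(q_x, kernel_size=3)
print(out.is_contiguous(memory_format=torch.channels_last))  # expected: True

# Point 2: because the channel stride is 1, each pooling window reduces to
# element-wise maxima over contiguous length-256 channel vectors, which is the
# access pattern the interleaved AVX2 vectorization in the kernel exploits.
```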
Benchmark script
```
import torch, time
for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    # Allocate with NHWC dimensions, then permute to NCHW dimension order so
    # both tensors carry a channels-last memory layout.
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_linear(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])

    NITER = 100

    # Time the float reference.
    s = time.time()
    for i in range(NITER):
        float_out = torch.max_pool2d(x, kernel_size=3, stride=None, padding=0, dilation=1)
    time_per_iter_float = (time.time() - s) / NITER

    # Time the quantized kernel.
    s = time.time()
    for i in range(NITER):
        quant_out = torch.max_pool2d(q_x, kernel_size=3, stride=None, padding=0, dilation=1)
    time_per_iter_quant = (time.time() - s) / NITER

    # Check the quantized result against the quantized float reference.
    ref_quantized = torch.quantize_linear(float_out, 0.5, 1, dtype)
    torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t')

    # Effective memory bandwidth: bytes read + written per iteration.
    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()
    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9
    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')
```
Before this change (DynDisp, i.e. dynamic CPU dispatch, to AVX2)
```
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
5.197856426239014 1.2381434440612793 0.23820270175433766
GB/s float GB/s quant
0.6816348335661166 0.7153936841878243
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
5.14232873916626 1.1790156364440918 0.2292765974808621
GB/s float GB/s quant
0.6889952353715999 0.7512707826941549
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
4.918942451477051 3.401169776916504 0.6914432950715265
GB/s float GB/s quant
0.7202849057394649 1.041712185038912
```
After this change (DynDisp, i.e. dynamic CPU dispatch, to AVX2)
```
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
5.0574493408203125 0.018107891082763672 0.0035804394394243393
GB/s float GB/s quant
0.700558673203699 48.915690731270566
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
4.984829425811768 0.016908645629882812 0.0033920209069399163
GB/s float GB/s quant
0.7107645412406512 52.38503540665539
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
4.973354339599609 0.13938188552856445 0.028025729922108406
GB/s float GB/s quant
0.7124044976624851 25.419658993448625
```
Test Plan: Imported from OSS
Differential Revision: D17196457
Pulled By: jamesr66a
fbshipit-source-id: 614be60ed74bed5d0369c58cc450b430cfabe5fb