NHWC specialization for quantized::cat (#26524)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26524
This creates an NHWC specialization for `quantized::cat` that kicks in when all inputs are NHWC. This ensures the correct layout is propagated downstream, and it provides an implementation optimized specifically for this data layout.
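When the channel dimension is innermost in memory (NHWC), concatenating along channels reduces to copying one contiguous channel run per spatial position from each input. A minimal pure-Python sketch of that access pattern (the function name and flat-buffer representation are illustrative, not the actual kernel code):

```python
# Illustrative sketch: channel-concat of two NHWC tensors stored as flat
# buffers. For each (n, h, w) position, each input's channels form one
# contiguous run, so the copy is two memcpy-like extends per position.
def cat_channels_nhwc(a, b, n, h, w, ca, cb):
    out = []
    for pos in range(n * h * w):
        out.extend(a[pos * ca:(pos + 1) * ca])  # contiguous channels from a
        out.extend(b[pos * cb:(pos + 1) * cb])  # contiguous channels from b
    return out
```

In the shapes below the per-position runs are 64-256 int8 channels, large enough for the copies to be bandwidth-bound rather than latency-bound.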
Benchmark script based on Squeezenet shapes:
```
import torch, time
torch.manual_seed(0)
# NHWC
sizes = [
    (1, 54, 54, 64),
    (1, 54, 54, 128),
    (1, 26, 26, 128),
    (1, 26, 26, 256),
    (1, 12, 12, 256),
]
for size in sizes:
    x = torch.rand(*size)
    y = torch.rand(*size)
    qX = torch.quantize_linear(x, 0.01, 3, torch.qint8).permute([0, 3, 1, 2])
    qY = torch.quantize_linear(y, 0.01, 3, torch.qint8).permute([0, 3, 1, 2])
    ref = torch.cat([qX.dequantize(), qY.dequantize()], dim=1)
    NITER = 1000
    s = time.time()
    for i in range(NITER):
        out = torch.ops.quantized.cat([qX, qY], dim=1, scale=0.01, zero_point=3)
    time_per_iter = (time.time() - s) / NITER
    print('time per iter ms', time_per_iter * 1000)
    print('gb/s', (qX.numel() + qY.numel() + out.numel()) * qX.element_size() / time_per_iter / 1e9)
    torch.testing.assert_allclose(out.dequantize(), ref)
```
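The `gb/s` line counts bytes read from both quantized inputs plus bytes written to the output, at one byte per element for qint8. As a sanity check of that arithmetic on the first shape, using the first measurement from the "Before" run below:

```python
# Bandwidth arithmetic for size (1, 54, 54, 64), qint8 (1 byte/element).
numel = 1 * 54 * 54 * 64                        # elements per input tensor
total_bytes = (numel + numel + 2 * numel) * 1   # two inputs read + output (2x channels) written
time_per_iter = 0.6898486614227295e-3           # seconds, first "before" time below
gbps = total_bytes / time_per_iter / 1e9
# gbps ~= 1.082, matching the first reported gb/s value
```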
Before this change
```
time per iter ms 0.6898486614227295
gb/s 1.0821156026605054
time per iter ms 1.5480577945709229
gb/s 0.9644291093239284
time per iter ms 0.3180875778198242
gb/s 1.0881028500775023
time per iter ms 0.6702737808227539
gb/s 1.032748139350315
time per iter ms 0.13010454177856445
gb/s 1.1333655073392244
```
After this change
```
time per iter ms 0.11604785919189453
gb/s 6.432656364350577
time per iter ms 0.15956878662109375
gb/s 9.356416324360508
time per iter ms 0.040181636810302734
gb/s 8.613685939027139
time per iter ms 0.06564664840698242
gb/s 10.544696748392909
time per iter ms 0.018549680709838867
gb/s 7.949247337814738
```
Test Plan: Imported from OSS
Differential Revision: D17503593
Pulled By: jamesr66a
fbshipit-source-id: ec5d57ad8fbcb3fd9379e8bd370abd29d386f953