improve the quantized batch_norm performance (#35639)
Summary:
The original batch_norm implementation is up to 2X slower than C2 for some shapes, especially when the remainder channel count (channels modulo the vector width of 32) is close to 32. For example, with a total channel size of 32*1 + 24, the 24 remaining channels are processed slowly in the original implementation.
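As a rough illustration of the shape mentioned above (and assuming a 32-lane vector width, which is not spelled out in this summary), the channels split into full vector blocks plus a remainder; if the remainder is handled by a slow scalar path, it can dominate the per-row cost:

```python
# Hypothetical illustration, not the actual PyTorch kernel code:
# a channel count of 32*1 + 24 = 56 splits into one full 32-lane
# vector block plus a 24-channel remainder.
VEC_WIDTH = 32          # assumed SIMD lane count
channels = 32 * 1 + 24  # the shape from the summary

full_blocks, remainder = divmod(channels, VEC_WIDTH)
print(full_blocks, remainder)  # → 1 24
```

When `remainder` is close to `VEC_WIDTH`, almost half the channels fall on the slow path, which is why this shape regresses the most.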
Benchmark
```
import torch, time

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    x = torch.rand(1, 4, 56, 56, 24)
    q_x = torch.quantize_per_tensor(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 4, 1, 2, 3])
    c = 24
    mean = torch.rand(c).float()
    var = torch.rand(c).float()
    weight = torch.rand(c).float()
    bias = torch.rand(c).float()
    eps = 0.001
    x = x.permute([0, 4, 1, 2, 3])
    NITER = 10

    # Float baseline: batch_norm + relu
    s = time.time()
    for i in range(NITER):
        float_out = torch.nn.functional.batch_norm(x, weight=weight, bias=bias, running_mean=mean, running_var=var, training=False, momentum=0, eps=eps)
        float_out = torch.nn.functional.relu(float_out)
    time_per_iter_float = (time.time() - s) / NITER

    # Quantized fused op
    s = time.time()
    for i in range(NITER):
        quant_out = torch.ops.quantized.batch_norm3d_relu(q_x, weight, bias, mean, var, eps, 0.5, 1)
    time_per_iter_quant = (time.time() - s) / NITER

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t')
```
After the change:
```
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
0.6527423858642578 1.649641990661621 2.5272481554532837
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
0.5787134170532227 1.040959358215332 1.7987475796152104
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
0.5466938018798828 2.262735366821289 4.138944614042739
```
Before the change:
```
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
0.7526159286499023 2.330636978149414 3.0967149238128426
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
0.21767616271972656 1.3946294784545898 6.406900328587075
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
0.24483203887939456 2.561521530151367 10.46236245009251
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35639
Differential Revision: D20723292
Pulled By: lly-zero-one
fbshipit-source-id: 66692eabaffb5030c2a37ec0f1322df3665411aa