improve batch_norm contiguous case's performance (#34530)
Summary:
For the batch_norm inference contiguous case, we can get better performance by manually vectorizing it.
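The key idea behind the optimization can be sketched in plain Python (this is an illustration of the math only, not the actual C++ kernel in this PR): at inference time the per-channel normalization folds into a single multiply-add per element, `y = alpha * x + beta`, with `alpha` and `beta` precomputed per channel. The remaining inner loop is a pure FMA over contiguous memory, which is what manual vectorization exploits. The function `batchnorm_inference` below is a hypothetical helper written for this example.

```python
import math

def batchnorm_inference(x, mean, var, weight, bias, eps=1e-5):
    # x: list of C channels, each a flat list of H*W contiguous values.
    # Fold  y = (x - mean) / sqrt(var + eps) * weight + bias
    # into  y = alpha * x + beta  with per-channel constants, so the
    # inner loop is a single multiply-add that vectorizes trivially.
    C = len(x)
    alpha = [weight[c] / math.sqrt(var[c] + eps) for c in range(C)]
    beta = [bias[c] - mean[c] * alpha[c] for c in range(C)]
    return [[alpha[c] * v + beta[c] for v in x[c]] for c in range(C)]
```
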
Test script:
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)
for n in [1, 10, 100]:
    for c in [1, 10, 100]:
        for hw in [1, 10, 200]:
            m = nn.BatchNorm2d(c, affine=False)
            m.eval()
            input = torch.randn(n, c, hw, hw)
            # warm up
            for i in range(200):
                output = m(input)
            fwd_t = 0
            for j in range(1000):
                t1 = time.time()
                output = m(input)
                t2 = time.time()
                fwd_t = fwd_t + (t2 - t1)
            fwd_avg = fwd_t / 1000 * 1000
            print("size = (%d, %d, %d, %d); compute time is %.4f(ms)" % (n, c, hw, hw, fwd_avg))
```
Before:
```
size = (1, 1, 1, 1); compute time is 0.0110(ms)
size = (1, 1, 10, 10); compute time is 0.0123(ms)
size = (1, 1, 200, 200); compute time is 0.8166(ms)
size = (1, 10, 1, 1); compute time is 0.0107(ms)
size = (1, 10, 10, 10); compute time is 0.0257(ms)
size = (1, 10, 200, 200); compute time is 8.7533(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.1619(ms)
size = (1, 100, 200, 200); compute time is 123.5674(ms)
size = (10, 1, 1, 1); compute time is 0.0109(ms)
size = (10, 1, 10, 10); compute time is 0.0123(ms)
size = (10, 1, 200, 200); compute time is 0.5629(ms)
size = (10, 10, 1, 1); compute time is 0.0107(ms)
size = (10, 10, 10, 10); compute time is 0.0253(ms)
size = (10, 10, 200, 200); compute time is 8.7817(ms)
size = (10, 100, 1, 1); compute time is 0.0120(ms)
size = (10, 100, 10, 10); compute time is 0.1655(ms)
size = (10, 100, 200, 200); compute time is 123.2488(ms)
size = (100, 1, 1, 1); compute time is 0.0109(ms)
size = (100, 1, 10, 10); compute time is 0.0123(ms)
size = (100, 1, 200, 200); compute time is 0.5740(ms)
size = (100, 10, 1, 1); compute time is 0.0108(ms)
size = (100, 10, 10, 10); compute time is 0.0257(ms)
size = (100, 10, 200, 200); compute time is 8.7201(ms)
size = (100, 100, 1, 1); compute time is 0.0122(ms)
size = (100, 100, 10, 10); compute time is 0.1628(ms)
size = (100, 100, 200, 200); compute time is 123.1739(ms)
```
After:
```
size = (1, 1, 1, 1); compute time is 0.0105(ms)
size = (1, 1, 10, 10); compute time is 0.0114(ms)
size = (1, 1, 200, 200); compute time is 0.5771(ms)
size = (1, 10, 1, 1); compute time is 0.0105(ms)
size = (1, 10, 10, 10); compute time is 0.0160(ms)
size = (1, 10, 200, 200); compute time is 6.9851(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.0848(ms)
size = (1, 100, 200, 200); compute time is 98.6758(ms)
size = (10, 1, 1, 1); compute time is 0.0105(ms)
size = (10, 1, 10, 10); compute time is 0.0115(ms)
size = (10, 1, 200, 200); compute time is 0.2690(ms)
size = (10, 10, 1, 1); compute time is 0.0105(ms)
size = (10, 10, 10, 10); compute time is 0.0159(ms)
size = (10, 10, 200, 200); compute time is 6.6946(ms)
size = (10, 100, 1, 1); compute time is 0.0123(ms)
size = (10, 100, 10, 10); compute time is 0.0854(ms)
size = (10, 100, 200, 200); compute time is 98.7327(ms)
size = (100, 1, 1, 1); compute time is 0.0107(ms)
size = (100, 1, 10, 10); compute time is 0.0116(ms)
size = (100, 1, 200, 200); compute time is 0.2681(ms)
size = (100, 10, 1, 1); compute time is 0.0104(ms)
size = (100, 10, 10, 10); compute time is 0.0159(ms)
size = (100, 10, 200, 200); compute time is 6.7507(ms)
size = (100, 100, 1, 1); compute time is 0.0124(ms)
size = (100, 100, 10, 10); compute time is 0.0852(ms)
size = (100, 100, 200, 200); compute time is 98.6866(ms)
```
For a real model, ResNeXt101, we also get a **~20%** throughput improvement at large batch size.
Test script:
```
import torch
import torchvision
import time

torch.manual_seed(0)
#torch.set_num_threads(1)
model = torchvision.models.resnext101_32x8d().eval()
for batch_size in [1, 64]:
    input = torch.randn(batch_size, 3, 224, 224)
    # warm up
    with torch.no_grad():
        for i in range(5):
            output = model(input)
        fwd_t = 0
        for i in range(10):
            t1 = time.time()
            output = model(input)
            t2 = time.time()
            fwd_t = fwd_t + (t2 - t1)
        time_fwd_avg = fwd_t / 10 * 1000
        print("Throughput of resnext101 with batch_size = %d is %10.2f (imgs/s)" % (batch_size, batch_size * 1000 / time_fwd_avg))
```
Before:
```
Throughput of resnext101 with batch_size = 1 is 7.89 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 13.02 (imgs/s)
num_threads = 1
Throughput of resnext101 with batch_size = 1 is 2.97 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 2.75 (imgs/s)
```
After:
```
Throughput of resnext101 with batch_size = 1 is 8.95 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 15.52 (imgs/s)
num_threads = 1
Throughput of resnext101 with batch_size = 1 is 3.10 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 2.88 (imgs/s)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34530
Differential Revision: D20479560
Pulled By: ngimel
fbshipit-source-id: 2e788ebcd814556116c90553ec61159eeffb3c16