batch_norm_cpu_inference for channel last (#28982)
Summary:
channels last version for batch_norm_cpu_inference_contiguous
Benchmark:
The benchmark test uses a fixed batch size n=20, channel number in [1,3,10,100,1000], height and width size in [1,4,16,64,256], height and width size are always the same in this test.
We use the following code to do this benchmark.
It tests contiguous, channels last and non-contiguous tensor in each loop and print out the benchmark. It also compare the outputs within each loop to make sure the correctness of the new change.
for c in [1,3,10,100,1000]:
for hw in [1,4,16,64,256]:
print('Benchmark n=20 c={0} h={1} w={2}'.format(c, hw, hw))
m = nn.BatchNorm2d(c, affine=False)
m.eval()
input = torch.randn(20, c, hw, hw)
output = m(input)
%timeit m(input)
for name, param in m.named_parameters():
if param.requires_grad:
if param.data.dim() == 4:
param.data = param.data.contiguous(memory_format=torch.channels_last)
m.eval()
input = input.contiguous(memory_format=torch.channels_last)
output1 = m(input)
%timeit m(input)
m = nn.BatchNorm2d(c, affine=False)
m.eval()
input = input.permute(0,1,3,2)
output2 = m(input)
%timeit m(input)
output2 = output2.permute(0,1,3,2)
print(output.equal(output1), output.equal(output2))
Sample output:
Benchmark n=20 c=100 h=256 w=256 -> title line
101 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> contiguous tensor
100 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) -> channels last tensor
1.3 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> non-contiguous tensor
True True -> 1st output compare with 2nd output, 1st output compare 3rd output, expect True
**Benchmark Before this change:**
Benchmark n=20 c=1 h=1 w=1
10.1 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.2 µs ± 305 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.7 µs ± 784 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=4 w=4
10.2 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.1 µs ± 98 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.5 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=16 w=16
11 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 148 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
17.3 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=64 w=64
24.2 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
23.9 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=1 h=256 w=256
539 µs ± 7.85 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
539 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.42 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=3 h=1 w=1
10 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.97 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.4 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=4 w=4
10.4 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
16.1 µs ± 601 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
19.1 µs ± 658 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=16 w=16
13.1 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
25.3 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32.4 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=3 h=64 w=64
51.1 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
159 µs ± 7.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
199 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=3 h=256 w=256
1.25 ms ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.95 ms ± 203 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.14 ms ± 42.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=10 h=1 w=1
9.97 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.5 µs ± 852 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11.7 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=4 w=4
11.2 µs ± 84.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
29.7 µs ± 343 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
39.4 µs ± 396 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=16 w=16
19.7 µs ± 632 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68.3 µs ± 912 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
90.3 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=64 w=64
325 µs ± 5.01 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
918 µs ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
991 µs ± 44.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=10 h=256 w=256
9.47 ms ± 73.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
34.7 ms ± 2.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
91.5 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=1 w=1
11.8 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.1 µs ± 800 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12 µs ± 533 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=100 h=4 w=4
26.7 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
231 µs ± 8.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
335 µs ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=16 w=16
178 µs ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.45 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.52 ms ± 94.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=64 w=64
6.9 ms ± 554 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
30.3 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
27 ms ± 272 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=256 w=256
98.9 ms ± 818 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.29 s ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.32 s ± 9.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=1 w=1
18.6 µs ± 2.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
18.7 µs ± 947 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
15.8 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1000 h=4 w=4
111 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
2.07 ms ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.19 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=1000 h=16 w=16
3.87 ms ± 336 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
25.6 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
27 ms ± 410 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=1000 h=64 w=64
70.1 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
467 ms ± 26.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
444 ms ± 25.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=256 w=256
2.39 s ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.2 s ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
22.1 s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
**Benchmark After this change:**
Benchmark n=20 c=1 h=1 w=1
10.4 µs ± 247 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.5 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.7 µs ± 237 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=4 w=4
11.8 µs ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
13.6 µs ± 142 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=16 w=16
11.9 µs ± 198 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.1 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
18.2 µs ± 205 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=64 w=64
27.6 µs ± 2.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32.2 µs ± 8.69 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68.9 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=1 h=256 w=256
601 µs ± 49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
597 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.48 ms ± 24.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=3 h=1 w=1
10.8 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.6 µs ± 194 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.5 µs ± 137 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=4 w=4
11.6 µs ± 551 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11.7 µs ± 266 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
19.9 µs ± 340 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=16 w=16
13.7 µs ± 223 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
24.7 µs ± 424 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33.7 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=3 h=64 w=64
53.3 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
212 µs ± 4.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
204 µs ± 5.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=3 h=256 w=256
1.49 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.27 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.08 ms ± 290 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=10 h=1 w=1
10.7 µs ± 166 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.8 µs ± 225 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.8 µs ± 192 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=10 h=4 w=4
11.6 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.9 µs ± 503 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
43.7 µs ± 3.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=16 w=16
20.7 µs ± 576 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
37.2 µs ± 795 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
92.5 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=64 w=64
342 µs ± 9.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
622 µs ± 37.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.03 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=10 h=256 w=256
9.49 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.9 ms ± 408 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
90.5 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=1 w=1
12 µs ± 575 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=100 h=4 w=4
22.3 µs ± 451 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
18.7 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
323 µs ± 6.22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=16 w=16
211 µs ± 22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
222 µs ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.5 ms ± 59.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=64 w=64
7.2 ms ± 1e+03 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.51 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
27.4 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=256 w=256
101 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
100 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.3 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=1 w=1
16.9 µs ± 589 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
16.5 µs ± 113 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
16.5 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1000 h=4 w=4
116 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
67 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
3.23 ms ± 80 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=1000 h=16 w=16
3.53 ms ± 72.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.53 ms ± 125 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
27 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=1000 h=64 w=64
68.6 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
68 ms ± 288 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
425 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=256 w=256
2.51 s ± 97.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.84 s ± 471 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
21.5 s ± 933 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
The channel last batch normalization is getting faster with this change and the previous existing code/logic is not affected based on the benchmark above.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28982
Reviewed By: VitalyFedyunin
Differential Revision: D18253305
Pulled By: glaringlee
fbshipit-source-id: a0fcac65544f10d736141ee70edeab8a3f1b3e02