pytorch
f2a35db2 - batch_norm_cpu_inference for channel last (#28982)

Commit
6 years ago
batch_norm_cpu_inference for channel last (#28982) Summary: channels last version for batch_norm_cpu_inference_contiguous Benchmark: The benchmark test uses a fixed batch size n=20, channel number in [1,3,10,100,1000], height and width size in [1,4,16,64,256], height and width size are always the same in this test. We use the following code to do this benchmark. It tests contiguous, channels last and non-contiguous tensor in each loop and print out the benchmark. It also compare the outputs within each loop to make sure the correctness of the new change. for c in [1,3,10,100,1000]: for hw in [1,4,16,64,256]: print('Benchmark n=20 c={0} h={1} w={2}'.format(c, hw, hw)) m = nn.BatchNorm2d(c, affine=False) m.eval() input = torch.randn(20, c, hw, hw) output = m(input) %timeit m(input) for name, param in m.named_parameters(): if param.requires_grad: if param.data.dim() == 4: param.data = param.data.contiguous(memory_format=torch.channels_last) m.eval() input = input.contiguous(memory_format=torch.channels_last) output1 = m(input) %timeit m(input) m = nn.BatchNorm2d(c, affine=False) m.eval() input = input.permute(0,1,3,2) output2 = m(input) %timeit m(input) output2 = output2.permute(0,1,3,2) print(output.equal(output1), output.equal(output2)) Sample output: Benchmark n=20 c=100 h=256 w=256 -> title line 101 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> contiguous tensor 100 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) -> channels last tensor 1.3 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> non-contiguous tensor True True -> 1st output compare with 2nd output, 1st output compare 3rd output, expect True **Benchmark Before this change:** Benchmark n=20 c=1 h=1 w=1 10.1 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.2 µs ± 305 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.7 µs ± 784 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1 h=4 w=4 10.2 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.1 µs ± 98 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 12.5 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1 h=16 w=16 11 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 11 µs ± 148 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 17.3 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1 h=64 w=64 24.2 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 23.9 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 66 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=1 h=256 w=256 539 µs ± 7.85 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 539 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.42 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=3 h=1 w=1 10 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 9.97 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.4 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=3 h=4 w=4 10.4 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 16.1 µs ± 601 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 19.1 µs ± 658 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=3 h=16 w=16 13.1 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 25.3 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 32.4 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=3 h=64 w=64 51.1 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 159 µs ± 7.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 199 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=3 h=256 w=256 1.25 ms ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 2.95 ms ± 203 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 6.14 ms ± 42.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) True True Benchmark n=20 c=10 h=1 w=1 9.97 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.5 µs ± 852 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 11.7 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=10 h=4 w=4 11.2 µs ± 84.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 29.7 µs ± 343 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 39.4 µs ± 396 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=10 h=16 w=16 19.7 µs ± 632 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 68.3 µs ± 912 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 90.3 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=10 h=64 w=64 325 µs ± 5.01 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 918 µs ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 991 µs ± 44.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=10 h=256 w=256 9.47 ms ± 73.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 34.7 ms ± 2.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 91.5 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) True True Benchmark n=20 c=100 h=1 w=1 11.8 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each) 12.1 µs ± 800 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 12 µs ± 533 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=100 h=4 w=4 26.7 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 231 µs ± 8.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 335 µs ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=100 h=16 w=16 178 µs ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 1.45 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.52 ms ± 94.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=100 h=64 w=64 6.9 ms ± 554 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 30.3 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 27 ms ± 272 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) True True Benchmark n=20 c=100 h=256 w=256 98.9 ms ± 818 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) 1.29 s ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 1.32 s ± 9.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) True True Benchmark n=20 c=1000 h=1 w=1 18.6 µs ± 2.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 18.7 µs ± 947 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 15.8 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1000 h=4 w=4 111 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 2.07 ms ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 3.19 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) True True Benchmark n=20 c=1000 h=16 w=16 3.87 ms ± 336 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 25.6 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 27 ms ± 410 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) True True Benchmark n=20 c=1000 h=64 w=64 70.1 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 467 ms ± 26.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 444 ms ± 25.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) True True Benchmark n=20 c=1000 h=256 w=256 2.39 s ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 19.2 s ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 22.1 s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each) True True **Benchmark After this change:** Benchmark n=20 c=1 h=1 w=1 10.4 µs ± 247 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.5 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.7 µs ± 237 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1 h=4 w=4 11.8 µs ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each) 11 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 13.6 µs ± 142 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1 h=16 w=16 11.9 µs ± 198 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 12.1 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 18.2 µs ± 205 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1 h=64 w=64 27.6 µs ± 2.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 32.2 µs ± 8.69 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 68.9 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=1 h=256 w=256 601 µs ± 49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 597 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.48 ms ± 24.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=3 h=1 w=1 10.8 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.6 µs ± 194 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.5 µs ± 137 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=3 h=4 w=4 11.6 µs ± 551 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 11.7 µs ± 266 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 19.9 µs ± 340 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=3 h=16 w=16 13.7 µs ± 223 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 24.7 µs ± 424 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 33.7 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=3 h=64 w=64 53.3 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 212 µs ± 4.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 204 µs ± 5.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=3 h=256 w=256 1.49 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 3.27 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 7.08 ms ± 290 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) True True Benchmark n=20 c=10 h=1 w=1 10.7 µs ± 166 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.8 µs ± 225 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 10.8 µs ± 192 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=10 h=4 w=4 11.6 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 12.9 µs ± 503 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 43.7 µs ± 3.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=10 h=16 w=16 20.7 µs ± 576 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 37.2 µs ± 795 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 92.5 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) True True Benchmark n=20 c=10 h=64 w=64 342 µs ± 9.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 622 µs ± 37.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.03 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=10 h=256 w=256 9.49 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 10.9 ms ± 408 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 90.5 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) True True Benchmark n=20 c=100 h=1 w=1 12 µs ± 575 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 11 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 11 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=100 h=4 w=4 22.3 µs ± 451 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 18.7 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 323 µs ± 6.22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=100 h=16 w=16 211 µs ± 22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 222 µs ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 1.5 ms ± 59.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) True True Benchmark n=20 c=100 h=64 w=64 7.2 ms ± 1e+03 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 6.51 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 27.4 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) True True Benchmark n=20 c=100 h=256 w=256 101 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 100 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) 1.3 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) True True Benchmark n=20 c=1000 h=1 w=1 16.9 µs ± 589 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 16.5 µs ± 113 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 16.5 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) True True Benchmark n=20 c=1000 h=4 w=4 116 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 67 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) 3.23 ms ± 80 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) True True Benchmark n=20 c=1000 h=16 w=16 3.53 ms ± 72.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 3.53 ms ± 125 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 27 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) True True Benchmark n=20 c=1000 h=64 w=64 68.6 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 68 ms ± 288 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 425 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) True True Benchmark n=20 c=1000 h=256 w=256 2.51 s ± 97.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 2.84 s ± 471 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 21.5 s ± 933 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) True True The channel last batch normalization is getting faster with this change and the previous existing code/logic is not affected based on the benchmark above. Pull Request resolved: https://github.com/pytorch/pytorch/pull/28982 Reviewed By: VitalyFedyunin Differential Revision: D18253305 Pulled By: glaringlee fbshipit-source-id: a0fcac65544f10d736141ee70edeab8a3f1b3e02
Author
Xinyu Li
Parents
Loading