pytorch
2397c8d1 - [pytorch] Improve/fix heuristics for using mkldnn vs native conv (#46675)

Commit
4 years ago
[pytorch] Improve/fix heuristics for using mkldnn vs native conv (#46675) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46675 We've found a few heuristics for using/not using mkldnn that seem to generally improve performance on 2d and 3d conv. - 1x1 convolutions are basically batch matmuls, and mkldnn's implementation appears to usually be slower than using the native conv (which lowers to aten::mm, which in turn calls mkl gemm). - 3d conv was often not using mkldnn even when it's beneficial, because the heuristic was checking the kernel depth rather than height/width. mkldnn seems to be faster for (1, 7, 7) and (3, 7, 7) kernel sizes, which are allowed by the new heuristic. Test Plan: Bento notebooks showing before/after: before: https://www.internalfb.com/intern/anp/view/?id=38089 after: https://www.internalfb.com/intern/anp/view/?id=380893 Also, I've run a conv fuzzer, and it generally supports these heuristics. I'm not sure how to best share the data since there's a lot of it (I tried about 50k parameter combinations). For the 1x1 case, about 70% were faster with "native". I played with constructing a decision tree (using scikit-learn) and found that switching back to MKL for batch size > 16 might be slightly better still, but I'm not sure it's worth complicating the heuristic. Results for some popular shapes in tabular format: ``` [------------------------- conv2d_1x1 ------------------------] | base | diff 1 threads: ---------------------------------------------------- [1, 128, 56, 56] [256, 128, 1, 1] | 3665.3 | 2838.4 [1, 512, 14, 14] [1024, 512, 1, 1] | 3174.7 | 3164.0 [1, 64, 56, 56] [256, 64, 1, 1] | 2249.1 | 1468.8 [1, 1024, 14, 14] [512, 1024, 1, 1] | 3158.2 | 3147.7 [1, 1024, 7, 7] [2048, 1024, 1, 1] | 8191.8 | 3973.9 [1, 2048, 7, 7] [1024, 2048, 1, 1] | 7901.2 | 3861.6 [1, 256, 28, 28] [512, 256, 1, 1] | 3103.9 | 2775.9 2 threads: ---------------------------------------------------- [1, 128, 56, 56] [256, 128, 1, 1] | 1973.7 | 1475.8 [1, 512, 14, 14] [1024, 512, 1, 1] | 2265.0 | 1603.0 [1, 64, 56, 56] [256, 64, 1, 1] | 1445.4 | 789.8 [1, 1024, 14, 14] [512, 1024, 1, 1] | 2298.8 | 1620.0 [1, 1024, 7, 7] [2048, 1024, 1, 1] | 6350.7 | 1995.0 [1, 2048, 7, 7] [1024, 2048, 1, 1] | 6471.2 | 1903.7 [1, 256, 28, 28] [512, 256, 1, 1] | 1932.3 | 1524.2 4 threads: ---------------------------------------------------- [1, 128, 56, 56] [256, 128, 1, 1] | 1198.8 | 785.6 [1, 512, 14, 14] [1024, 512, 1, 1] | 1305.0 | 901.6 [1, 64, 56, 56] [256, 64, 1, 1] | 791.0 | 472.9 [1, 1024, 14, 14] [512, 1024, 1, 1] | 1311.2 | 908.5 [1, 1024, 7, 7] [2048, 1024, 1, 1] | 3958.6 | 997.7 [1, 2048, 7, 7] [1024, 2048, 1, 1] | 4099.6 | 1023.1 [1, 256, 28, 28] [512, 256, 1, 1] | 1120.3 | 740.8 Times are in microseconds (us). [--------------------- conv2d_7x7 ---------------------] | base | diff 1 threads: --------------------------------------------- [25, 3, 48, 320] [64, 3, 7, 7] | 209.3 | 229.3 [1, 3, 384, 288] [64, 3, 7, 7] | 68.9 | 72.3 2 threads: --------------------------------------------- [25, 3, 48, 320] [64, 3, 7, 7] | 116.0 | 117.6 [1, 3, 384, 288] [64, 3, 7, 7] | 40.4 | 38.7 4 threads: --------------------------------------------- [25, 3, 48, 320] [64, 3, 7, 7] | 64.2 | 66.5 [1, 3, 384, 288] [64, 3, 7, 7] | 21.4 | 21.9 Times are in milliseconds (ms). [---------------------------- conv3d ---------------------------] | base | diff 1 threads: ------------------------------------------------------ [1, 3, 16, 224, 224] [32, 3, 1, 7, 7] | 602.8 | 296.2 [1, 3, 4, 112, 112] [64, 3, 3, 7, 7] | 52.5 | 26.5 [1, 256, 8, 14, 14] [256, 256, 3, 3, 3] | 50.0 | 50.3 2 threads: ------------------------------------------------------ [1, 3, 16, 224, 224] [32, 3, 1, 7, 7] | 351.0 | 168.1 [1, 3, 4, 112, 112] [64, 3, 3, 7, 7] | 38.5 | 14.9 [1, 256, 8, 14, 14] [256, 256, 3, 3, 3] | 24.8 | 26.2 4 threads: ------------------------------------------------------ [1, 3, 16, 224, 224] [32, 3, 1, 7, 7] | 212.6 | 96.0 [1, 3, 4, 112, 112] [64, 3, 3, 7, 7] | 21.5 | 7.6 [1, 256, 8, 14, 14] [256, 256, 3, 3, 3] | 12.7 | 13.3 Times are in milliseconds (ms). ``` Reviewed By: jansel Differential Revision: D24452071 fbshipit-source-id: 12687971be531831530dc29bf2fc079a917d0c8d
Author
Parents
Loading