[pytorch] Improve/fix heuristics for using mkldnn vs native conv (#46675)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46675
We've found a few heuristics for choosing between mkldnn and native conv that
generally improve performance on 2d and 3d conv.
- 1x1 convolutions are basically batch matmuls, and mkldnn's implementation
usually appears to be slower than the native conv (which lowers to
aten::mm, which in turn calls mkl gemm).
- 3d conv was often not using mkldnn even when it would have been beneficial,
because the heuristic was checking the kernel depth rather than height/width.
mkldnn seems to be faster for (1, 7, 7) and (3, 7, 7) kernel sizes, which the
new heuristic allows.
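As a rough illustration of the two rules above, here is a minimal Python sketch.
It is not the C++ check changed in this PR: the name `prefer_mkldnn`, the
`large_kernel_threshold` parameter and its default of 3 are assumptions made for
the example, and any other conditions the real check applies are ignored.

```python
from typing import Sequence


def prefer_mkldnn(kernel_size: Sequence[int], large_kernel_threshold: int = 3) -> bool:
    """Hypothetical predicate: True if this sketch would pick mkldnn over native conv.

    kernel_size holds only the spatial kernel dims, e.g. (kH, kW) for conv2d
    or (kD, kH, kW) for conv3d.
    """
    # 1x1 (or 1x1x1) kernels behave like a matmul, so prefer the native path.
    if all(k == 1 for k in kernel_size):
        return False
    # Look at kernel height and width (the last two dims), not the depth.
    k_h, k_w = kernel_size[-2], kernel_size[-1]
    # Assumed threshold: treat kernels larger than 3x3 in H/W as "large".
    return k_h > large_kernel_threshold and k_w > large_kernel_threshold


if __name__ == "__main__":
    for ks in [(1, 1), (7, 7), (1, 7, 7), (3, 7, 7), (3, 3, 3)]:
        print(ks, "mkldnn" if prefer_mkldnn(ks) else "native")
```

With these assumptions, (1, 7, 7) and (3, 7, 7) go to mkldnn because only
height/width are inspected, while 1x1 kernels stay on the native path.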
Test Plan:
Bento notebooks showing before/after:
before: https://www.internalfb.com/intern/anp/view/?id=38089
after: https://www.internalfb.com/intern/anp/view/?id=380893
Also, I've run a conv fuzzer, and it generally supports these heuristics. I'm
not sure how best to share the data since there's a lot of it (I tried about
50k parameter combinations).
For the 1x1 case, about 70% of the combinations were faster with "native". I
played with constructing a decision tree (using scikit-learn) and found that
switching back to MKL for batch size > 16 might be slightly better still, but
I'm not sure it's worth complicating the heuristic.
Results for some popular shapes in tabular format:
```
[------------------------- conv2d_1x1 ------------------------]
                                      |   base |   diff
1 threads: ----------------------------------------------------
  [1, 128, 56, 56] [256, 128, 1, 1]   | 3665.3 | 2838.4
  [1, 512, 14, 14] [1024, 512, 1, 1]  | 3174.7 | 3164.0
  [1, 64, 56, 56] [256, 64, 1, 1]     | 2249.1 | 1468.8
  [1, 1024, 14, 14] [512, 1024, 1, 1] | 3158.2 | 3147.7
  [1, 1024, 7, 7] [2048, 1024, 1, 1]  | 8191.8 | 3973.9
  [1, 2048, 7, 7] [1024, 2048, 1, 1]  | 7901.2 | 3861.6
  [1, 256, 28, 28] [512, 256, 1, 1]   | 3103.9 | 2775.9
2 threads: ----------------------------------------------------
  [1, 128, 56, 56] [256, 128, 1, 1]   | 1973.7 | 1475.8
  [1, 512, 14, 14] [1024, 512, 1, 1]  | 2265.0 | 1603.0
  [1, 64, 56, 56] [256, 64, 1, 1]     | 1445.4 |  789.8
  [1, 1024, 14, 14] [512, 1024, 1, 1] | 2298.8 | 1620.0
  [1, 1024, 7, 7] [2048, 1024, 1, 1]  | 6350.7 | 1995.0
  [1, 2048, 7, 7] [1024, 2048, 1, 1]  | 6471.2 | 1903.7
  [1, 256, 28, 28] [512, 256, 1, 1]   | 1932.3 | 1524.2
4 threads: ----------------------------------------------------
  [1, 128, 56, 56] [256, 128, 1, 1]   | 1198.8 |  785.6
  [1, 512, 14, 14] [1024, 512, 1, 1]  | 1305.0 |  901.6
  [1, 64, 56, 56] [256, 64, 1, 1]     |  791.0 |  472.9
  [1, 1024, 14, 14] [512, 1024, 1, 1] | 1311.2 |  908.5
  [1, 1024, 7, 7] [2048, 1024, 1, 1]  | 3958.6 |  997.7
  [1, 2048, 7, 7] [1024, 2048, 1, 1]  | 4099.6 | 1023.1
  [1, 256, 28, 28] [512, 256, 1, 1]   | 1120.3 |  740.8
Times are in microseconds (us).

[--------------------- conv2d_7x7 ---------------------]
                                 |  base |  diff
1 threads: ---------------------------------------------
  [25, 3, 48, 320] [64, 3, 7, 7] | 209.3 | 229.3
  [1, 3, 384, 288] [64, 3, 7, 7] |  68.9 |  72.3
2 threads: ---------------------------------------------
  [25, 3, 48, 320] [64, 3, 7, 7] | 116.0 | 117.6
  [1, 3, 384, 288] [64, 3, 7, 7] |  40.4 |  38.7
4 threads: ---------------------------------------------
  [25, 3, 48, 320] [64, 3, 7, 7] |  64.2 |  66.5
  [1, 3, 384, 288] [64, 3, 7, 7] |  21.4 |  21.9
Times are in milliseconds (ms).

[---------------------------- conv3d ---------------------------]
                                          |  base |  diff
1 threads: ------------------------------------------------------
  [1, 3, 16, 224, 224] [32, 3, 1, 7, 7]   | 602.8 | 296.2
  [1, 3, 4, 112, 112] [64, 3, 3, 7, 7]    |  52.5 |  26.5
  [1, 256, 8, 14, 14] [256, 256, 3, 3, 3] |  50.0 |  50.3
2 threads: ------------------------------------------------------
  [1, 3, 16, 224, 224] [32, 3, 1, 7, 7]   | 351.0 | 168.1
  [1, 3, 4, 112, 112] [64, 3, 3, 7, 7]    |  38.5 |  14.9
  [1, 256, 8, 14, 14] [256, 256, 3, 3, 3] |  24.8 |  26.2
4 threads: ------------------------------------------------------
  [1, 3, 16, 224, 224] [32, 3, 1, 7, 7]   | 212.6 |  96.0
  [1, 3, 4, 112, 112] [64, 3, 3, 7, 7]    |  21.5 |   7.6
  [1, 256, 8, 14, 14] [256, 256, 3, 3, 3] |  12.7 |  13.3
Times are in milliseconds (ms).
```
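To read the tables as speedups, divide the base time by the diff time; the
ratio is unitless, so the us/ms difference between tables does not matter. A
small sketch using three 1-thread rows copied from the tables above (the helper
name `speedup` is only for illustration):

```python
# (label, base time, diff time) copied from the 1-thread rows of the tables above.
ROWS = [
    ("conv2d_1x1 [1, 1024, 7, 7] [2048, 1024, 1, 1]", 8191.8, 3973.9),
    ("conv3d     [1, 3, 16, 224, 224] [32, 3, 1, 7, 7]", 602.8, 296.2),
    ("conv2d_7x7 [25, 3, 48, 320] [64, 3, 7, 7]", 209.3, 229.3),
]


def speedup(base: float, diff: float) -> float:
    """Return base/diff; values above 1.0 mean the diff is faster than base."""
    return base / diff


if __name__ == "__main__":
    for label, base, diff in ROWS:
        print(f"{label}: {speedup(base, diff):.2f}x")
```

For these rows the 1x1 and conv3d cases come out at roughly 2x, while the 7x7
conv2d row is slightly below 1x.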
Reviewed By: jansel
Differential Revision: D24452071
fbshipit-source-id: 12687971be531831530dc29bf2fc079a917d0c8d