Don't implicitly convert to channels-first in MaxPool3D on CUDA (#80748)
MaxPool3D currently converts inputs implicitly to channels-first (via `.contiguous()`) which may yield unexpected regressions in workloads that expect a full channels-last path. This PR preserves the channels-last format in MaxPool3D while attempting to avoid seriously regressing performance.
Currently, typical case (kernel size == 2 == stride) looks good, but larger kernel sizes (>4) or the unusual case of stride 1 can sometimes be slower than converting to channels-first before doing MaxPool3D.
Additionally, this PR adds a test for 64bit-indexing backwards as testing of these changes uncovered an IMA for large tensors when doing the backwards pass with MaxPool3D.
Performance comparison on A6000:
```
[------------------------------------- max_pool3d ---------------------------------------------------------]
| channels_last=False | curr ch_last=True | new ch_last=True
1 threads: ---------------------------------------------------------------------------- ---------------------
[64, 256, 32, 32, 32] 4x4 stride 4 | 20093.5 | 34823.4 | 20640.0
[64, 256, 32, 32, 32] 4x4 stride 2 | 28623.7 | 42625.6 | 27935.5
[64, 256, 32, 32, 32] 4x4 stride 1 | 68177.5 | 79147.2 | 85604.8
[64, 256, 32, 32, 32] 2x2 stride 4 | 17237.7 | 32071.3 | 16641.6
[64, 256, 32, 32, 32] 2x2 stride 2 | 25252.5 | 39993.2 | 25054.8
[64, 256, 32, 32, 32] 2x2 stride 1 | 43185.2 | 58164.6 | 48416.9
[64, 256, 16, 16, 16] 4x4 stride 4 | 3017.7 | 3952.4 | 2593.8
[64, 256, 16, 16, 16] 4x4 stride 2 | 4581.5 | 5384.3 | 3294.3
[64, 256, 16, 16, 16] 4x4 stride 1 | 11334.1 | 11534.7 | 8651.1
[64, 256, 16, 16, 16] 2x2 stride 4 | 2346.9 | 3304.6 | 2098.8
[64, 256, 16, 16, 16] 2x2 stride 2 | 3550.8 | 4526.5 | 3143.6
[64, 256, 16, 16, 16] 2x2 stride 1 | 6898.1 | 7816.0 | 5820.8
[64, 256, 4, 4, 4] 4x4 stride 4 | 191.5 | 176.3 | 77.5
[64, 256, 4, 4, 4] 4x4 stride 2 | 191.8 | 176.8 | 94.1
[64, 256, 4, 4, 4] 4x4 stride 1 | 191.3 | 176.4 | 97.3
[64, 256, 4, 4, 4] 2x2 stride 4 | 96.4 | 114.4 | 93.6
[64, 256, 4, 4, 4] 2x2 stride 2 | 172.1 | 178.6 | 93.7
[64, 256, 4, 4, 4] 2x2 stride 1 | 263.0 | 279.4 | 92.4
[64, 64, 32, 32, 32] 4x4 stride 4 | 5033.2 | 7208.3 | 5167.5
[64, 64, 32, 32, 32] 4x4 stride 2 | 7216.1 | 9218.7 | 6637.1
[64, 64, 32, 32, 32] 4x4 stride 1 | 17192.1 | 18392.9 | 20489.0
[64, 64, 32, 32, 32] 2x2 stride 4 | 4318.0 | 6511.2 | 4193.1
[64, 64, 32, 32, 32] 2x2 stride 2 | 6324.4 | 8657.7 | 6263.6
[64, 64, 32, 32, 32] 2x2 stride 1 | 10855.0 | 13040.2 | 12055.9
[64, 64, 16, 16, 16] 4x4 stride 4 | 764.1 | 975.6 | 671.3
[64, 64, 16, 16, 16] 4x4 stride 2 | 1163.1 | 1333.4 | 833.6
[64, 64, 16, 16, 16] 4x4 stride 1 | 2890.0 | 2898.5 | 2209.8
[64, 64, 16, 16, 16] 2x2 stride 4 | 593.5 | 811.2 | 536.3
[64, 64, 16, 16, 16] 2x2 stride 2 | 895.9 | 1112.3 | 794.5
[64, 64, 16, 16, 16] 2x2 stride 1 | 1742.5 | 1968.0 | 1475.2
[64, 64, 4, 4, 4] 4x4 stride 4 | 101.1 | 112.2 | 93.4
[64, 64, 4, 4, 4] 4x4 stride 2 | 96.7 | 114.6 | 92.5
[64, 64, 4, 4, 4] 4x4 stride 1 | 98.9 | 111.9 | 96.5
[64, 64, 4, 4, 4] 2x2 stride 4 | 100.1 | 107.1 | 94.2
[64, 64, 4, 4, 4] 2x2 stride 2 | 96.6 | 108.0 | 94.5
[64, 64, 4, 4, 4] 2x2 stride 1 | 96.7 | 107.9 | 95.2
[64, 3, 32, 32, 32] 4x4 stride 4 | 250.1 | 326.6 | 278.0
[64, 3, 32, 32, 32] 4x4 stride 2 | 350.4 | 414.0 | 323.2
[64, 3, 32, 32, 32] 4x4 stride 1 | 825.6 | 846.9 | 982.5
[64, 3, 32, 32, 32] 2x2 stride 4 | 213.3 | 289.8 | 219.9
[64, 3, 32, 32, 32] 2x2 stride 2 | 308.2 | 384.9 | 305.9
[64, 3, 32, 32, 32] 2x2 stride 1 | 523.5 | 594.7 | 589.9
[64, 3, 16, 16, 16] 4x4 stride 4 | 103.8 | 116.7 | 93.0
[64, 3, 16, 16, 16] 4x4 stride 2 | 100.9 | 108.3 | 93.3
[64, 3, 16, 16, 16] 4x4 stride 1 | 139.4 | 140.7 | 104.8
[64, 3, 16, 16, 16] 2x2 stride 4 | 97.5 | 114.7 | 92.7
[64, 3, 16, 16, 16] 2x2 stride 2 | 97.4 | 108.8 | 91.7
[64, 3, 16, 16, 16] 2x2 stride 1 | 99.9 | 108.0 | 94.1
[64, 3, 4, 4, 4] 4x4 stride 4 | 97.2 | 110.2 | 94.7
[64, 3, 4, 4, 4] 4x4 stride 2 | 105.7 | 107.4 | 92.8
[64, 3, 4, 4, 4] 4x4 stride 1 | 98.0 | 110.0 | 93.7
[64, 3, 4, 4, 4] 2x2 stride 4 | 98.3 | 116.7 | 93.0
[64, 3, 4, 4, 4] 2x2 stride 2 | 98.6 | 107.5 | 92.8
[64, 3, 4, 4, 4] 2x2 stride 1 | 100.6 | 110.3 | 94.0
[16, 256, 32, 32, 32] 4x4 stride 4 | 5034.2 | 8838.0 | 5165.9
[16, 256, 32, 32, 32] 4x4 stride 2 | 7236.3 | 10869.9 | 7038.2
[16, 256, 32, 32, 32] 4x4 stride 1 | 17385.4 | 21401.6 | 21900.7
[16, 256, 32, 32, 32] 2x2 stride 4 | 4318.7 | 8101.2 | 4172.9
[16, 256, 32, 32, 32] 2x2 stride 2 | 6324.0 | 10147.5 | 6279.7
[16, 256, 32, 32, 32] 2x2 stride 1 | 10899.7 | 14826.0 | 12256.3
[16, 256, 16, 16, 16] 4x4 stride 4 | 765.4 | 1012.7 | 675.6
[16, 256, 16, 16, 16] 4x4 stride 2 | 1162.8 | 1376.9 | 843.4
[16, 256, 16, 16, 16] 4x4 stride 1 | 2928.9 | 2969.8 | 2222.5
[16, 256, 16, 16, 16] 2x2 stride 4 | 593.5 | 845.8 | 534.2
[16, 256, 16, 16, 16] 2x2 stride 2 | 896.9 | 1152.2 | 796.9
[16, 256, 16, 16, 16] 2x2 stride 1 | 1750.2 | 2009.4 | 1481.8
[16, 256, 4, 4, 4] 4x4 stride 4 | 96.6 | 107.1 | 92.7
[16, 256, 4, 4, 4] 4x4 stride 2 | 97.9 | 114.9 | 93.8
[16, 256, 4, 4, 4] 4x4 stride 1 | 98.2 | 115.6 | 94.0
[16, 256, 4, 4, 4] 2x2 stride 4 | 97.0 | 106.7 | 93.8
[16, 256, 4, 4, 4] 2x2 stride 2 | 96.8 | 108.1 | 93.3
[16, 256, 4, 4, 4] 2x2 stride 1 | 95.8 | 120.9 | 95.7
[16, 64, 32, 32, 32] 4x4 stride 4 | 1266.4 | 1815.4 | 1312.3
[16, 64, 32, 32, 32] 4x4 stride 2 | 1818.5 | 2328.0 | 1678.9
[16, 64, 32, 32, 32] 4x4 stride 1 | 4352.9 | 4649.3 | 5204.6
[16, 64, 32, 32, 32] 2x2 stride 4 | 1090.0 | 1631.2 | 1060.8
[16, 64, 32, 32, 32] 2x2 stride 2 | 1589.4 | 2141.1 | 1576.4
[16, 64, 32, 32, 32] 2x2 stride 1 | 2733.5 | 3286.0 | 3041.6
[16, 64, 16, 16, 16] 4x4 stride 4 | 201.7 | 259.6 | 175.0
[16, 64, 16, 16, 16] 4x4 stride 2 | 301.0 | 350.1 | 226.3
[16, 64, 16, 16, 16] 4x4 stride 1 | 740.1 | 748.7 | 570.6
[16, 64, 16, 16, 16] 2x2 stride 4 | 156.0 | 214.8 | 140.8
[16, 64, 16, 16, 16] 2x2 stride 2 | 232.3 | 292.3 | 208.7
[16, 64, 16, 16, 16] 2x2 stride 1 | 449.1 | 504.0 | 382.1
[16, 64, 4, 4, 4] 4x4 stride 4 | 97.5 | 111.4 | 94.5
[16, 64, 4, 4, 4] 4x4 stride 2 | 98.8 | 111.9 | 94.4
[16, 64, 4, 4, 4] 4x4 stride 1 | 98.2 | 112.0 | 95.2
[16, 64, 4, 4, 4] 2x2 stride 4 | 99.7 | 111.0 | 94.0
[16, 64, 4, 4, 4] 2x2 stride 2 | 100.3 | 110.0 | 93.2
[16, 64, 4, 4, 4] 2x2 stride 1 | 97.5 | 107.6 | 93.5
[16, 3, 32, 32, 32] 4x4 stride 4 | 100.5 | 117.1 | 95.7
[16, 3, 32, 32, 32] 4x4 stride 2 | 97.5 | 121.3 | 92.5
[16, 3, 32, 32, 32] 4x4 stride 1 | 216.0 | 227.4 | 258.4
[16, 3, 32, 32, 32] 2x2 stride 4 | 97.1 | 109.0 | 91.9
[16, 3, 32, 32, 32] 2x2 stride 2 | 95.8 | 108.5 | 92.9
[16, 3, 32, 32, 32] 2x2 stride 1 | 139.4 | 161.2 | 157.8
[16, 3, 16, 16, 16] 4x4 stride 4 | 96.4 | 113.6 | 91.9
[16, 3, 16, 16, 16] 4x4 stride 2 | 97.4 | 108.1 | 93.5
[16, 3, 16, 16, 16] 4x4 stride 1 | 99.0 | 107.5 | 92.1
[16, 3, 16, 16, 16] 2x2 stride 4 | 96.9 | 118.1 | 93.4
[16, 3, 16, 16, 16] 2x2 stride 2 | 97.3 | 106.7 | 95.8
[16, 3, 16, 16, 16] 2x2 stride 1 | 98.8 | 109.2 | 93.8
[16, 3, 4, 4, 4] 4x4 stride 4 | 97.8 | 108.0 | 94.2
[16, 3, 4, 4, 4] 4x4 stride 2 | 92.7 | 108.0 | 93.9
[16, 3, 4, 4, 4] 4x4 stride 1 | 97.8 | 107.6 | 93.5
[16, 3, 4, 4, 4] 2x2 stride 4 | 100.3 | 107.7 | 94.3
[16, 3, 4, 4, 4] 2x2 stride 2 | 97.2 | 107.5 | 96.1
[16, 3, 4, 4, 4] 2x2 stride 1 | 98.1 | 111.1 | 93.8
Times are in microseconds (us).
```
Performance comparison on V100:
(these times have been updated after working around some noisy measurements in my setup)
```
[------------------------------------- max_pool3d ---------------------------------------------------------]
| channels_last=False | curr ch_last=True | new ch_last=True
1 threads: -------------------------------------------------------------------------------------------------
[64, 256, 32, 32, 32] 4x4 stride 4 | 15810.7 | 33807.7 | 16452.9
[64, 256, 32, 32, 32] 4x4 stride 2 | 24422.7 | 42515.3 | 27700.3
[64, 256, 32, 32, 32] 4x4 stride 1 | 71756.0 | 89916.5 | 106464.0
[64, 256, 32, 32, 32] 2x2 stride 4 | 12102.9 | 30210.4 | 11319.8
[64, 256, 32, 32, 32] 2x2 stride 2 | 19101.7 | 37210.8 | 20373.3
[64, 256, 32, 32, 32] 2x2 stride 1 | 41418.0 | 59650.5 | 53009.2
[64, 256, 16, 16, 16] 4x4 stride 4 | 2362.0 | 4210.3 | 2114.0
[64, 256, 16, 16, 16] 4x4 stride 2 | 4102.4 | 5897.4 | 3179.7
[64, 256, 16, 16, 16] 4x4 stride 1 | 11339.3 | 13116.6 | 10032.6
[64, 256, 16, 16, 16] 2x2 stride 4 | 1709.7 | 3506.7 | 1423.6
[64, 256, 16, 16, 16] 2x2 stride 2 | 2966.6 | 4760.8 | 2499.3
[64, 256, 16, 16, 16] 2x2 stride 1 | 6998.4 | 8797.3 | 6152.0
[64, 256, 4, 4, 4] 4x4 stride 4 | 173.0 | 176.3 | 127.9
[64, 256, 4, 4, 4] 4x4 stride 2 | 149.1 | 176.3 | 125.5
[64, 256, 4, 4, 4] 4x4 stride 1 | 150.0 | 177.2 | 125.6
[64, 256, 4, 4, 4] 2x2 stride 4 | 158.0 | 192.7 | 127.9
[64, 256, 4, 4, 4] 2x2 stride 2 | 169.7 | 199.2 | 125.3
[64, 256, 4, 4, 4] 2x2 stride 1 | 289.6 | 318.2 | 116.5
[64, 64, 32, 32, 32] 4x4 stride 4 | 3914.4 | 6993.3 | 4141.4
[64, 64, 32, 32, 32] 4x4 stride 2 | 6107.4 | 9186.4 | 6378.5
[64, 64, 32, 32, 32] 4x4 stride 1 | 17920.0 | 20993.5 | 23891.1
[64, 64, 32, 32, 32] 2x2 stride 4 | 3029.7 | 6112.6 | 2895.6
[64, 64, 32, 32, 32] 2x2 stride 2 | 4787.8 | 7870.6 | 4724.8
[64, 64, 32, 32, 32] 2x2 stride 1 | 10366.4 | 13446.4 | 12603.8
[64, 64, 16, 16, 16] 4x4 stride 4 | 605.8 | 962.9 | 499.7
[64, 64, 16, 16, 16] 4x4 stride 2 | 1037.0 | 1394.8 | 791.6
[64, 64, 16, 16, 16] 4x4 stride 1 | 2835.4 | 3191.8 | 2484.3
[64, 64, 16, 16, 16] 2x2 stride 4 | 438.6 | 795.7 | 368.6
[64, 64, 16, 16, 16] 2x2 stride 2 | 749.1 | 1108.0 | 612.0
[64, 64, 16, 16, 16] 2x2 stride 1 | 1756.4 | 2112.2 | 1538.5
[64, 64, 4, 4, 4] 4x4 stride 4 | 132.6 | 163.9 | 115.4
[64, 64, 4, 4, 4] 4x4 stride 2 | 129.3 | 153.7 | 117.8
[64, 64, 4, 4, 4] 4x4 stride 1 | 128.0 | 153.8 | 117.6
[64, 64, 4, 4, 4] 2x2 stride 4 | 128.2 | 154.1 | 117.5
[64, 64, 4, 4, 4] 2x2 stride 2 | 130.5 | 157.3 | 117.6
[64, 64, 4, 4, 4] 2x2 stride 1 | 128.8 | 156.4 | 120.6
[64, 3, 32, 32, 32] 4x4 stride 4 | 200.4 | 261.0 | 228.8
[64, 3, 32, 32, 32] 4x4 stride 2 | 305.3 | 366.5 | 344.4
[64, 3, 32, 32, 32] 4x4 stride 1 | 860.9 | 922.1 | 1136.0
[64, 3, 32, 32, 32] 2x2 stride 4 | 157.0 | 216.9 | 158.1
[64, 3, 32, 32, 32] 2x2 stride 2 | 240.5 | 300.9 | 247.7
[64, 3, 32, 32, 32] 2x2 stride 1 | 503.5 | 565.1 | 609.8
[64, 3, 16, 16, 16] 4x4 stride 4 | 136.0 | 159.0 | 120.3
[64, 3, 16, 16, 16] 4x4 stride 2 | 131.2 | 156.9 | 120.0
[64, 3, 16, 16, 16] 4x4 stride 1 | 146.6 | 158.5 | 123.8
[64, 3, 16, 16, 16] 2x2 stride 4 | 133.8 | 158.4 | 117.1
[64, 3, 16, 16, 16] 2x2 stride 2 | 132.1 | 160.8 | 117.9
[64, 3, 16, 16, 16] 2x2 stride 1 | 133.7 | 174.4 | 118.0
[64, 3, 4, 4, 4] 4x4 stride 4 | 156.8 | 166.2 | 119.4
[64, 3, 4, 4, 4] 4x4 stride 2 | 126.8 | 150.4 | 118.2
[64, 3, 4, 4, 4] 4x4 stride 1 | 125.2 | 151.7 | 117.8
[64, 3, 4, 4, 4] 2x2 stride 4 | 127.3 | 152.7 | 116.2
[64, 3, 4, 4, 4] 2x2 stride 2 | 128.6 | 153.3 | 114.6
[64, 3, 4, 4, 4] 2x2 stride 1 | 128.6 | 153.5 | 114.7
[16, 256, 32, 32, 32] 4x4 stride 4 | 3921.7 | 8445.7 | 4064.7
[16, 256, 32, 32, 32] 4x4 stride 2 | 6111.7 | 10630.0 | 6944.4
[16, 256, 32, 32, 32] 4x4 stride 1 | 17938.9 | 22896.8 | 26648.7
[16, 256, 32, 32, 32] 2x2 stride 4 | 3029.6 | 7552.7 | 2840.9
[16, 256, 32, 32, 32] 2x2 stride 2 | 4788.0 | 9322.1 | 5110.5
[16, 256, 32, 32, 32] 2x2 stride 1 | 10363.7 | 14885.9 | 13213.6
[16, 256, 16, 16, 16] 4x4 stride 4 | 606.0 | 1059.1 | 535.9
[16, 256, 16, 16, 16] 4x4 stride 2 | 1037.5 | 1491.5 | 822.3
[16, 256, 16, 16, 16] 4x4 stride 1 | 2835.4 | 3306.8 | 2522.8
[16, 256, 16, 16, 16] 2x2 stride 4 | 438.6 | 892.3 | 369.0
[16, 256, 16, 16, 16] 2x2 stride 2 | 749.2 | 1203.7 | 638.7
[16, 256, 16, 16, 16] 2x2 stride 1 | 1756.1 | 2212.5 | 1547.0
[16, 256, 4, 4, 4] 4x4 stride 4 | 159.6 | 187.6 | 117.6
[16, 256, 4, 4, 4] 4x4 stride 2 | 161.1 | 185.5 | 117.3
[16, 256, 4, 4, 4] 4x4 stride 1 | 160.0 | 148.1 | 117.8
[16, 256, 4, 4, 4] 2x2 stride 4 | 123.9 | 148.3 | 117.6
[16, 256, 4, 4, 4] 2x2 stride 2 | 126.0 | 151.7 | 117.4
[16, 256, 4, 4, 4] 2x2 stride 1 | 127.1 | 152.3 | 117.9
[16, 64, 32, 32, 32] 4x4 stride 4 | 983.5 | 1756.7 | 1067.8
[16, 64, 32, 32, 32] 4x4 stride 2 | 1542.4 | 2315.2 | 1621.5
[16, 64, 32, 32, 32] 4x4 stride 1 | 4498.7 | 5273.4 | 6006.7
[16, 64, 32, 32, 32] 2x2 stride 4 | 767.2 | 1543.4 | 736.7
[16, 64, 32, 32, 32] 2x2 stride 2 | 1207.8 | 1981.5 | 1197.0
[16, 64, 32, 32, 32] 2x2 stride 1 | 2603.3 | 3367.5 | 3161.9
[16, 64, 16, 16, 16] 4x4 stride 4 | 169.5 | 264.6 | 142.8
[16, 64, 16, 16, 16] 4x4 stride 2 | 274.6 | 368.9 | 216.8
[16, 64, 16, 16, 16] 4x4 stride 1 | 723.3 | 820.4 | 643.2
[16, 64, 16, 16, 16] 2x2 stride 4 | 131.4 | 216.0 | 116.1
[16, 64, 16, 16, 16] 2x2 stride 2 | 199.9 | 295.0 | 166.8
```
CC @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80748
Approved by: https://github.com/ngimel