Improve interpolate() speed for channels_last images and masks (#86361)
This PR improves the speed of `interpolate()`:
- on images and masks (`num_channels < 4`, `channels_last=True`)
- for the following modes: linear (antialias=False), nearest (int and float), and nearest-exact (int and float)
- for both upsampling and downsampling
The actual speed-up ranges from 1.1X to 110X, but this depends on various factors like number of threads and of course input_size/output_size. In a typical torchvision ImageNet training job (where num_threads=1 because of DataLoader multi-processing), the following speed-ups should be expected (I ran much more benchmarks than this one, see below for more details):
```
(1, 3, 600, 400) -> (224, 224) linear float32 num_threads=1 1.0X 1.0ms vs 1.0ms
(1, 3, 600, 400) -> (224, 224) nearest float32 num_threads=1 1.9X 0.9ms vs 0.5ms
(1, 3, 600, 400) -> (224, 224) nearest uint8 num_threads=1 1.7X 0.9ms vs 0.5ms
(1, 3, 600, 400) -> (224, 224) nearest-exact float32 num_threads=1 2.1X 1.0ms vs 0.5ms
(1, 3, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=1 1.8X 0.9ms vs 0.5ms
(1, 1, 600, 400) -> (224, 224) linear float32 num_threads=1 7X 0.8ms vs 0.1ms
(1, 1, 600, 400) -> (224, 224) nearest float32 num_threads=1 14X 0.852ms vs 0.061ms
(1, 1, 600, 400) -> (224, 224) nearest uint8 num_threads=1 9X 0.828ms vs 0.087ms
(1, 1, 600, 400) -> (224, 224) nearest-exact float32 num_threads=1 15X 0.922ms vs 0.061ms
(1, 1, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=1 10X 0.897ms vs 0.087ms
```
An immediate follow-up to this PR would be to do the same changes for the 3D kernels.
Thanks a ton @fmassa for the help!
### Speedup benchmarks:
Results:
<details>
```
----------------------------------------------------------------------------------------------------
(1, 3, 64, 64) -> (224, 224) linear float32 num_threads=1 0.9X 0.9ms vs 1.1ms
(1, 3, 64, 64) -> (224, 224) nearest float32 num_threads=1 1.6X 0.9ms vs 0.5ms
(1, 3, 64, 64) -> (224, 224) nearest uint8 num_threads=1 1.7X 0.9ms vs 0.5ms
(1, 3, 64, 64) -> (224, 224) nearest-exact float32 num_threads=1 1.7X 1.0ms vs 0.5ms
(1, 3, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=1 1.9X 0.9ms vs 0.5ms
(1, 1, 64, 64) -> (224, 224) linear float32 num_threads=1 8X 0.806ms vs 0.097ms
(1, 1, 64, 64) -> (224, 224) nearest float32 num_threads=1 15X 0.848ms vs 0.056ms
(1, 1, 64, 64) -> (224, 224) nearest uint8 num_threads=1 10X 0.828ms vs 0.084ms
(1, 1, 64, 64) -> (224, 224) nearest-exact float32 num_threads=1 16X 0.914ms vs 0.057ms
(1, 1, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=1 10X 0.900ms vs 0.086ms
(1, 3, 64, 64) -> (224, 224) linear float32 num_threads=2 1.6X 1.1ms vs 0.7ms
(1, 3, 64, 64) -> (224, 224) nearest float32 num_threads=2 1.6X 0.6ms vs 0.4ms
(1, 3, 64, 64) -> (224, 224) nearest uint8 num_threads=2 1.7X 0.4ms vs 0.3ms
(1, 3, 64, 64) -> (224, 224) nearest-exact float32 num_threads=2 1.7X 0.6ms vs 0.4ms
(1, 3, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=2 1.7X 0.5ms vs 0.3ms
(1, 1, 64, 64) -> (224, 224) linear float32 num_threads=2 9X 0.800ms vs 0.088ms
(1, 1, 64, 64) -> (224, 224) nearest float32 num_threads=2 11X 0.459ms vs 0.043ms
(1, 1, 64, 64) -> (224, 224) nearest uint8 num_threads=2 7X 0.424ms vs 0.064ms
(1, 1, 64, 64) -> (224, 224) nearest-exact float32 num_threads=2 12X 0.503ms vs 0.043ms
(1, 1, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=2 8X 0.461ms vs 0.059ms
(1, 3, 64, 64) -> (224, 224) linear float32 num_threads=12 3X 1.1ms vs 0.3ms
(1, 3, 64, 64) -> (224, 224) nearest float32 num_threads=12 1.6X 0.3ms vs 0.2ms
(1, 3, 64, 64) -> (224, 224) nearest uint8 num_threads=12 1.5X 0.2ms vs 0.1ms
(1, 3, 64, 64) -> (224, 224) nearest-exact float32 num_threads=12 1.5X 0.3ms vs 0.2ms
(1, 3, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=12 1.5X 0.2ms vs 0.1ms
(1, 1, 64, 64) -> (224, 224) linear float32 num_threads=12 5X 0.8ms vs 0.2ms
(1, 1, 64, 64) -> (224, 224) nearest float32 num_threads=12 10X 0.445ms vs 0.047ms
(1, 1, 64, 64) -> (224, 224) nearest uint8 num_threads=12 7X 0.432ms vs 0.062ms
(1, 1, 64, 64) -> (224, 224) nearest-exact float32 num_threads=12 10X 0.478ms vs 0.046ms
(1, 1, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=12 7X 0.470ms vs 0.063ms
(1, 3, 64, 64) -> (224, 224) linear float32 num_threads=32 3X 1.1ms vs 0.4ms
(1, 3, 64, 64) -> (224, 224) nearest float32 num_threads=32 1.8X 0.3ms vs 0.2ms
(1, 3, 64, 64) -> (224, 224) nearest uint8 num_threads=32 1.5X 0.2ms vs 0.1ms
(1, 3, 64, 64) -> (224, 224) nearest-exact float32 num_threads=32 1.4X 0.3ms vs 0.2ms
(1, 3, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=32 1.5X 0.2ms vs 0.1ms
(1, 1, 64, 64) -> (224, 224) linear float32 num_threads=32 11X 0.815ms vs 0.074ms
(1, 1, 64, 64) -> (224, 224) nearest float32 num_threads=32 10X 0.443ms vs 0.045ms
(1, 1, 64, 64) -> (224, 224) nearest uint8 num_threads=32 7X 0.436ms vs 0.061ms
(1, 1, 64, 64) -> (224, 224) nearest-exact float32 num_threads=32 10X 0.478ms vs 0.046ms
(1, 1, 64, 64) -> (224, 224) nearest-exact uint8 num_threads=32 8X 0.470ms vs 0.061ms
----------------------------------------------------------------------------------------------------
(1, 3, 128, 128) -> (224, 224) linear float32 num_threads=1 0.9X 0.9ms vs 1.1ms
(1, 3, 128, 128) -> (224, 224) nearest float32 num_threads=1 1.5X 0.9ms vs 0.6ms
(1, 3, 128, 128) -> (224, 224) nearest uint8 num_threads=1 1.7X 0.9ms vs 0.5ms
(1, 3, 128, 128) -> (224, 224) nearest-exact float32 num_threads=1 1.6X 1.0ms vs 0.6ms
(1, 3, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=1 1.8X 0.9ms vs 0.5ms
(1, 1, 128, 128) -> (224, 224) linear float32 num_threads=1 8X 0.808ms vs 0.099ms
(1, 1, 128, 128) -> (224, 224) nearest float32 num_threads=1 15X 0.848ms vs 0.058ms
(1, 1, 128, 128) -> (224, 224) nearest uint8 num_threads=1 9X 0.820ms vs 0.087ms
(1, 1, 128, 128) -> (224, 224) nearest-exact float32 num_threads=1 16X 0.909ms vs 0.059ms
(1, 1, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=1 10X 0.898ms vs 0.088ms
(1, 3, 128, 128) -> (224, 224) linear float32 num_threads=2 1.4X 0.9ms vs 0.7ms
(1, 3, 128, 128) -> (224, 224) nearest float32 num_threads=2 1.5X 0.5ms vs 0.3ms
(1, 3, 128, 128) -> (224, 224) nearest uint8 num_threads=2 1.7X 0.4ms vs 0.3ms
(1, 3, 128, 128) -> (224, 224) nearest-exact float32 num_threads=2 1.5X 0.5ms vs 0.4ms
(1, 3, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=2 1.8X 0.5ms vs 0.3ms
(1, 1, 128, 128) -> (224, 224) linear float32 num_threads=2 9X 0.799ms vs 0.090ms
(1, 1, 128, 128) -> (224, 224) nearest float32 num_threads=2 10X 0.459ms vs 0.045ms
(1, 1, 128, 128) -> (224, 224) nearest uint8 num_threads=2 7X 0.427ms vs 0.059ms
(1, 1, 128, 128) -> (224, 224) nearest-exact float32 num_threads=2 11X 0.501ms vs 0.044ms
(1, 1, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=2 8X 0.460ms vs 0.060ms
(1, 3, 128, 128) -> (224, 224) linear float32 num_threads=12 2.9X 1.0ms vs 0.3ms
(1, 3, 128, 128) -> (224, 224) nearest float32 num_threads=12 1.2X 0.2ms vs 0.2ms
(1, 3, 128, 128) -> (224, 224) nearest uint8 num_threads=12 1.5X 0.2ms vs 0.1ms
(1, 3, 128, 128) -> (224, 224) nearest-exact float32 num_threads=12 1.1X 0.2ms vs 0.2ms
(1, 3, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=12 1.6X 0.2ms vs 0.1ms
(1, 1, 128, 128) -> (224, 224) linear float32 num_threads=12 12X 0.809ms vs 0.068ms
(1, 1, 128, 128) -> (224, 224) nearest float32 num_threads=12 11X 0.438ms vs 0.041ms
(1, 1, 128, 128) -> (224, 224) nearest uint8 num_threads=12 8X 0.432ms vs 0.055ms
(1, 1, 128, 128) -> (224, 224) nearest-exact float32 num_threads=12 12X 0.480ms vs 0.041ms
(1, 1, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=12 8X 0.464ms vs 0.056ms
(1, 3, 128, 128) -> (224, 224) linear float32 num_threads=32 3X 1.1ms vs 0.3ms
(1, 3, 128, 128) -> (224, 224) nearest float32 num_threads=32 1.3X 0.3ms vs 0.2ms
(1, 3, 128, 128) -> (224, 224) nearest uint8 num_threads=32 1.5X 0.2ms vs 0.1ms
(1, 3, 128, 128) -> (224, 224) nearest-exact float32 num_threads=32 1.4X 0.3ms vs 0.2ms
(1, 3, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=32 1.6X 0.2ms vs 0.1ms
(1, 1, 128, 128) -> (224, 224) linear float32 num_threads=32 11X 0.813ms vs 0.075ms
(1, 1, 128, 128) -> (224, 224) nearest float32 num_threads=32 10X 0.443ms vs 0.046ms
(1, 1, 128, 128) -> (224, 224) nearest uint8 num_threads=32 7X 0.433ms vs 0.061ms
(1, 1, 128, 128) -> (224, 224) nearest-exact float32 num_threads=32 10X 0.478ms vs 0.046ms
(1, 1, 128, 128) -> (224, 224) nearest-exact uint8 num_threads=32 8X 0.470ms vs 0.062ms
----------------------------------------------------------------------------------------------------
(1, 3, 224, 224) -> (600, 400) linear float32 num_threads=1 0.9X 4.5ms vs 5.2ms
(1, 3, 224, 224) -> (600, 400) nearest float32 num_threads=1 1.5X 4.2ms vs 2.8ms
(1, 3, 224, 224) -> (600, 400) nearest uint8 num_threads=1 1.8X 4.1ms vs 2.3ms
(1, 3, 224, 224) -> (600, 400) nearest-exact float32 num_threads=1 1.6X 4.5ms vs 2.8ms
(1, 3, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=1 1.9X 4.4ms vs 2.3ms
(1, 1, 224, 224) -> (600, 400) linear float32 num_threads=1 9X 3.8ms vs 0.4ms
(1, 1, 224, 224) -> (600, 400) nearest float32 num_threads=1 17X 4.0ms vs 0.2ms
(1, 1, 224, 224) -> (600, 400) nearest uint8 num_threads=1 11X 3.9ms vs 0.4ms
(1, 1, 224, 224) -> (600, 400) nearest-exact float32 num_threads=1 19X 4.4ms vs 0.2ms
(1, 1, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=1 12X 4.3ms vs 0.4ms
(1, 3, 224, 224) -> (600, 400) linear float32 num_threads=2 1.5X 4.5ms vs 3.1ms
(1, 3, 224, 224) -> (600, 400) nearest float32 num_threads=2 1.4X 2.3ms vs 1.6ms
(1, 3, 224, 224) -> (600, 400) nearest uint8 num_threads=2 1.7X 2.1ms vs 1.2ms
(1, 3, 224, 224) -> (600, 400) nearest-exact float32 num_threads=2 1.6X 2.5ms vs 1.6ms
(1, 3, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=2 1.8X 2.2ms vs 1.2ms
(1, 1, 224, 224) -> (600, 400) linear float32 num_threads=2 15X 3.8ms vs 0.3ms
(1, 1, 224, 224) -> (600, 400) nearest float32 num_threads=2 15X 2.2ms vs 0.1ms
(1, 1, 224, 224) -> (600, 400) nearest uint8 num_threads=2 7X 2.0ms vs 0.3ms
(1, 1, 224, 224) -> (600, 400) nearest-exact float32 num_threads=2 16X 2.4ms vs 0.1ms
(1, 1, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=2 8X 2.2ms vs 0.3ms
(1, 3, 224, 224) -> (600, 400) linear float32 num_threads=12 8X 5.2ms vs 0.7ms
(1, 3, 224, 224) -> (600, 400) nearest float32 num_threads=12 1.3X 0.6ms vs 0.4ms
(1, 3, 224, 224) -> (600, 400) nearest uint8 num_threads=12 1.7X 0.4ms vs 0.2ms
(1, 3, 224, 224) -> (600, 400) nearest-exact float32 num_threads=12 1.4X 0.6ms vs 0.4ms
(1, 3, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=12 1.8X 0.4ms vs 0.2ms
(1, 1, 224, 224) -> (600, 400) linear float32 num_threads=12 36X 3.9ms vs 0.1ms
(1, 1, 224, 224) -> (600, 400) nearest float32 num_threads=12 10X 0.526ms vs 0.051ms
(1, 1, 224, 224) -> (600, 400) nearest uint8 num_threads=12 7X 0.514ms vs 0.069ms
(1, 1, 224, 224) -> (600, 400) nearest-exact float32 num_threads=12 11X 0.569ms vs 0.052ms
(1, 1, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=12 8X 0.557ms vs 0.070ms
(1, 3, 224, 224) -> (600, 400) linear float32 num_threads=32 9X 4.5ms vs 0.5ms
(1, 3, 224, 224) -> (600, 400) nearest float32 num_threads=32 0.5X 0.2ms vs 0.5ms
(1, 3, 224, 224) -> (600, 400) nearest uint8 num_threads=32 1.5X 0.2ms vs 0.1ms
(1, 3, 224, 224) -> (600, 400) nearest-exact float32 num_threads=32 1.0X 0.5ms vs 0.5ms
(1, 3, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=32 1.6X 0.2ms vs 0.1ms
(1, 1, 224, 224) -> (600, 400) linear float32 num_threads=32 44X 3.864ms vs 0.087ms
(1, 1, 224, 224) -> (600, 400) nearest float32 num_threads=32 10X 0.527ms vs 0.053ms
(1, 1, 224, 224) -> (600, 400) nearest uint8 num_threads=32 7X 0.516ms vs 0.070ms
(1, 1, 224, 224) -> (600, 400) nearest-exact float32 num_threads=32 10X 0.567ms vs 0.055ms
(1, 1, 224, 224) -> (600, 400) nearest-exact uint8 num_threads=32 8X 0.558ms vs 0.072ms
----------------------------------------------------------------------------------------------------
(1, 3, 256, 256) -> (320, 320) linear float32 num_threads=1 1.0X 1.9ms vs 1.9ms
(1, 3, 256, 256) -> (320, 320) nearest float32 num_threads=1 2.0X 1.8ms vs 0.9ms
(1, 3, 256, 256) -> (320, 320) nearest uint8 num_threads=1 1.7X 1.8ms vs 1.0ms
(1, 3, 256, 256) -> (320, 320) nearest-exact float32 num_threads=1 2.1X 1.9ms vs 0.9ms
(1, 3, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=1 1.9X 1.9ms vs 1.0ms
(1, 1, 256, 256) -> (320, 320) linear float32 num_threads=1 9X 1.6ms vs 0.2ms
(1, 1, 256, 256) -> (320, 320) nearest float32 num_threads=1 16X 1.7ms vs 0.1ms
(1, 1, 256, 256) -> (320, 320) nearest uint8 num_threads=1 10X 1.7ms vs 0.2ms
(1, 1, 256, 256) -> (320, 320) nearest-exact float32 num_threads=1 17X 1.9ms vs 0.1ms
(1, 1, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=1 11X 1.8ms vs 0.2ms
(1, 3, 256, 256) -> (320, 320) linear float32 num_threads=2 1.7X 1.9ms vs 1.1ms
(1, 3, 256, 256) -> (320, 320) nearest float32 num_threads=2 2.0X 1.0ms vs 0.5ms
(1, 3, 256, 256) -> (320, 320) nearest uint8 num_threads=2 1.7X 0.9ms vs 0.5ms
(1, 3, 256, 256) -> (320, 320) nearest-exact float32 num_threads=2 2.3X 1.1ms vs 0.5ms
(1, 3, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=2 1.8X 1.0ms vs 0.5ms
(1, 1, 256, 256) -> (320, 320) linear float32 num_threads=2 8X 1.6ms vs 0.2ms
(1, 1, 256, 256) -> (320, 320) nearest float32 num_threads=2 14X 0.931ms vs 0.067ms
(1, 1, 256, 256) -> (320, 320) nearest uint8 num_threads=2 7X 0.9ms vs 0.1ms
(1, 1, 256, 256) -> (320, 320) nearest-exact float32 num_threads=2 15X 1.016ms vs 0.069ms
(1, 1, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=2 9X 0.9ms vs 0.1ms
(1, 3, 256, 256) -> (320, 320) linear float32 num_threads=12 8X 1.9ms vs 0.3ms
(1, 3, 256, 256) -> (320, 320) nearest float32 num_threads=12 1.7X 0.2ms vs 0.1ms
(1, 3, 256, 256) -> (320, 320) nearest uint8 num_threads=12 1.5X 0.2ms vs 0.1ms
(1, 3, 256, 256) -> (320, 320) nearest-exact float32 num_threads=12 1.9X 0.2ms vs 0.1ms
(1, 3, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=12 1.6X 0.2ms vs 0.1ms
(1, 1, 256, 256) -> (320, 320) linear float32 num_threads=12 20X 1.630ms vs 0.081ms
(1, 1, 256, 256) -> (320, 320) nearest float32 num_threads=12 10X 0.457ms vs 0.044ms
(1, 1, 256, 256) -> (320, 320) nearest uint8 num_threads=12 7X 0.439ms vs 0.060ms
(1, 1, 256, 256) -> (320, 320) nearest-exact float32 num_threads=12 11X 0.485ms vs 0.045ms
(1, 1, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=12 8X 0.474ms vs 0.061ms
(1, 3, 256, 256) -> (320, 320) linear float32 num_threads=32 8X 1.9ms vs 0.3ms
(1, 3, 256, 256) -> (320, 320) nearest float32 num_threads=32 2.0X 0.2ms vs 0.1ms
(1, 3, 256, 256) -> (320, 320) nearest uint8 num_threads=32 1.6X 0.2ms vs 0.1ms
(1, 3, 256, 256) -> (320, 320) nearest-exact float32 num_threads=32 1.4X 0.2ms vs 0.2ms
(1, 3, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=32 1.4X 0.2ms vs 0.1ms
(1, 1, 256, 256) -> (320, 320) linear float32 num_threads=32 21X 1.628ms vs 0.078ms
(1, 1, 256, 256) -> (320, 320) nearest float32 num_threads=32 9X 0.453ms vs 0.048ms
(1, 1, 256, 256) -> (320, 320) nearest uint8 num_threads=32 7X 0.445ms vs 0.063ms
(1, 1, 256, 256) -> (320, 320) nearest-exact float32 num_threads=32 11X 0.535ms vs 0.048ms
(1, 1, 256, 256) -> (320, 320) nearest-exact uint8 num_threads=32 8X 0.502ms vs 0.063ms
----------------------------------------------------------------------------------------------------
(1, 3, 500, 500) -> (800, 800) linear float32 num_threads=1 1.0X 13.8ms vs 14.0ms
(1, 3, 500, 500) -> (800, 800) nearest float32 num_threads=1 1.8X 13.1ms vs 7.4ms
(1, 3, 500, 500) -> (800, 800) nearest uint8 num_threads=1 1.8X 11.1ms vs 6.1ms
(1, 3, 500, 500) -> (800, 800) nearest-exact float32 num_threads=1 1.9X 13.9ms vs 7.4ms
(1, 3, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=1 1.9X 11.8ms vs 6.1ms
(1, 1, 500, 500) -> (800, 800) linear float32 num_threads=1 10X 10.2ms vs 1.1ms
(1, 1, 500, 500) -> (800, 800) nearest float32 num_threads=1 19X 10.8ms vs 0.6ms
(1, 1, 500, 500) -> (800, 800) nearest uint8 num_threads=1 11X 10.4ms vs 0.9ms
(1, 1, 500, 500) -> (800, 800) nearest-exact float32 num_threads=1 20X 11.6ms vs 0.6ms
(1, 1, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=1 12X 11.4ms vs 0.9ms
(1, 3, 500, 500) -> (800, 800) linear float32 num_threads=2 1.8X 13.7ms vs 7.7ms
(1, 3, 500, 500) -> (800, 800) nearest float32 num_threads=2 2.6X 7.3ms vs 2.8ms
(1, 3, 500, 500) -> (800, 800) nearest uint8 num_threads=2 1.8X 5.6ms vs 3.1ms
(1, 3, 500, 500) -> (800, 800) nearest-exact float32 num_threads=2 1.9X 7.9ms vs 4.1ms
(1, 3, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=2 1.9X 6.0ms vs 3.1ms
(1, 1, 500, 500) -> (800, 800) linear float32 num_threads=2 18X 10.1ms vs 0.6ms
(1, 1, 500, 500) -> (800, 800) nearest float32 num_threads=2 19X 5.8ms vs 0.3ms
(1, 1, 500, 500) -> (800, 800) nearest uint8 num_threads=2 10X 5.3ms vs 0.5ms
(1, 1, 500, 500) -> (800, 800) nearest-exact float32 num_threads=2 20X 6.3ms vs 0.3ms
(1, 1, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=2 11X 5.7ms vs 0.5ms
(1, 3, 500, 500) -> (800, 800) linear float32 num_threads=12 8X 13.8ms vs 1.6ms
(1, 3, 500, 500) -> (800, 800) nearest float32 num_threads=12 2.9X 1.5ms vs 0.5ms
(1, 3, 500, 500) -> (800, 800) nearest uint8 num_threads=12 1.7X 1.0ms vs 0.5ms
(1, 3, 500, 500) -> (800, 800) nearest-exact float32 num_threads=12 1.5X 1.5ms vs 1.0ms
(1, 3, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=12 1.8X 1.0ms vs 0.6ms
(1, 1, 500, 500) -> (800, 800) linear float32 num_threads=12 80X 10.1ms vs 0.1ms
(1, 1, 500, 500) -> (800, 800) nearest float32 num_threads=12 13X 0.928ms vs 0.072ms
(1, 1, 500, 500) -> (800, 800) nearest uint8 num_threads=12 8X 0.9ms vs 0.1ms
(1, 1, 500, 500) -> (800, 800) nearest-exact float32 num_threads=12 13X 1.001ms vs 0.074ms
(1, 1, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=12 9X 1.0ms vs 0.1ms
(1, 3, 500, 500) -> (800, 800) linear float32 num_threads=32 18X 14.0ms vs 0.8ms
(1, 3, 500, 500) -> (800, 800) nearest float32 num_threads=32 1.9X 1.0ms vs 0.6ms
(1, 3, 500, 500) -> (800, 800) nearest uint8 num_threads=32 2.9X 0.7ms vs 0.2ms
(1, 3, 500, 500) -> (800, 800) nearest-exact float32 num_threads=32 1.7X 0.9ms vs 0.6ms
(1, 3, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=32 1.8X 0.4ms vs 0.2ms
(1, 1, 500, 500) -> (800, 800) linear float32 num_threads=32 111X 10.254ms vs 0.092ms
(1, 1, 500, 500) -> (800, 800) nearest float32 num_threads=32 14X 0.784ms vs 0.056ms
(1, 1, 500, 500) -> (800, 800) nearest uint8 num_threads=32 7X 0.551ms vs 0.075ms
(1, 1, 500, 500) -> (800, 800) nearest-exact float32 num_threads=32 11X 0.607ms vs 0.057ms
(1, 1, 500, 500) -> (800, 800) nearest-exact uint8 num_threads=32 8X 0.596ms vs 0.076ms
----------------------------------------------------------------------------------------------------
(1, 3, 224, 224) -> (64, 64) linear float32 num_threads=1 1.0X 0.084ms vs 0.084ms
(1, 3, 224, 224) -> (64, 64) nearest float32 num_threads=1 1.0X 0.077ms vs 0.078ms
(1, 3, 224, 224) -> (64, 64) nearest uint8 num_threads=1 1.0X 0.076ms vs 0.076ms
(1, 3, 224, 224) -> (64, 64) nearest-exact float32 num_threads=1 1.0X 0.083ms vs 0.083ms
(1, 3, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=1 1.0X 0.081ms vs 0.082ms
(1, 1, 224, 224) -> (64, 64) linear float32 num_threads=1 1.0X 0.071ms vs 0.071ms
(1, 1, 224, 224) -> (64, 64) nearest float32 num_threads=1 1.0X 0.074ms vs 0.074ms
(1, 1, 224, 224) -> (64, 64) nearest uint8 num_threads=1 1.0X 0.072ms vs 0.072ms
(1, 1, 224, 224) -> (64, 64) nearest-exact float32 num_threads=1 1.0X 0.080ms vs 0.080ms
(1, 1, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=1 0.9X 0.078ms vs 0.084ms
(1, 3, 224, 224) -> (64, 64) linear float32 num_threads=2 1.0X 0.083ms vs 0.084ms
(1, 3, 224, 224) -> (64, 64) nearest float32 num_threads=2 1.0X 0.076ms vs 0.077ms
(1, 3, 224, 224) -> (64, 64) nearest uint8 num_threads=2 1.0X 0.075ms vs 0.074ms
(1, 3, 224, 224) -> (64, 64) nearest-exact float32 num_threads=2 1.0X 0.082ms vs 0.083ms
(1, 3, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=2 1.0X 0.080ms vs 0.083ms
(1, 1, 224, 224) -> (64, 64) linear float32 num_threads=2 1.0X 0.070ms vs 0.071ms
(1, 1, 224, 224) -> (64, 64) nearest float32 num_threads=2 1.0X 0.073ms vs 0.075ms
(1, 1, 224, 224) -> (64, 64) nearest uint8 num_threads=2 1.0X 0.071ms vs 0.072ms
(1, 1, 224, 224) -> (64, 64) nearest-exact float32 num_threads=2 1.0X 0.079ms vs 0.080ms
(1, 1, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=2 1.0X 0.077ms vs 0.079ms
(1, 3, 224, 224) -> (64, 64) linear float32 num_threads=12 1.0X 0.083ms vs 0.084ms
(1, 3, 224, 224) -> (64, 64) nearest float32 num_threads=12 1.0X 0.080ms vs 0.078ms
(1, 3, 224, 224) -> (64, 64) nearest uint8 num_threads=12 1.0X 0.077ms vs 0.075ms
(1, 3, 224, 224) -> (64, 64) nearest-exact float32 num_threads=12 1.0X 0.083ms vs 0.083ms
(1, 3, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=12 1.0X 0.083ms vs 0.082ms
(1, 1, 224, 224) -> (64, 64) linear float32 num_threads=12 1.0X 0.071ms vs 0.071ms
(1, 1, 224, 224) -> (64, 64) nearest float32 num_threads=12 1.0X 0.076ms vs 0.074ms
(1, 1, 224, 224) -> (64, 64) nearest uint8 num_threads=12 1.0X 0.073ms vs 0.071ms
(1, 1, 224, 224) -> (64, 64) nearest-exact float32 num_threads=12 1.0X 0.080ms vs 0.080ms
(1, 1, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=12 1.0X 0.080ms vs 0.078ms
(1, 3, 224, 224) -> (64, 64) linear float32 num_threads=32 1.0X 0.084ms vs 0.084ms
(1, 3, 224, 224) -> (64, 64) nearest float32 num_threads=32 1.0X 0.078ms vs 0.077ms
(1, 3, 224, 224) -> (64, 64) nearest uint8 num_threads=32 1.0X 0.076ms vs 0.076ms
(1, 3, 224, 224) -> (64, 64) nearest-exact float32 num_threads=32 1.0X 0.083ms vs 0.083ms
(1, 3, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=32 1.0X 0.081ms vs 0.082ms
(1, 1, 224, 224) -> (64, 64) linear float32 num_threads=32 1.0X 0.072ms vs 0.072ms
(1, 1, 224, 224) -> (64, 64) nearest float32 num_threads=32 1.0X 0.074ms vs 0.075ms
(1, 1, 224, 224) -> (64, 64) nearest uint8 num_threads=32 1.0X 0.072ms vs 0.072ms
(1, 1, 224, 224) -> (64, 64) nearest-exact float32 num_threads=32 1.0X 0.077ms vs 0.080ms
(1, 1, 224, 224) -> (64, 64) nearest-exact uint8 num_threads=32 1.0X 0.076ms vs 0.079ms
----------------------------------------------------------------------------------------------------
(1, 3, 224, 224) -> (128, 128) linear float32 num_threads=1 1.0X 0.3ms vs 0.3ms
(1, 3, 224, 224) -> (128, 128) nearest float32 num_threads=1 1.8X 0.3ms vs 0.2ms
(1, 3, 224, 224) -> (128, 128) nearest uint8 num_threads=1 1.6X 0.3ms vs 0.2ms
(1, 3, 224, 224) -> (128, 128) nearest-exact float32 num_threads=1 2.0X 0.3ms vs 0.2ms
(1, 3, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=1 1.7X 0.3ms vs 0.2ms
(1, 1, 224, 224) -> (128, 128) linear float32 num_threads=1 6X 0.265ms vs 0.044ms
(1, 1, 224, 224) -> (128, 128) nearest float32 num_threads=1 10X 0.280ms vs 0.028ms
(1, 1, 224, 224) -> (128, 128) nearest uint8 num_threads=1 7X 0.273ms vs 0.037ms
(1, 1, 224, 224) -> (128, 128) nearest-exact float32 num_threads=1 11X 0.303ms vs 0.028ms
(1, 1, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=1 8X 0.297ms vs 0.038ms
(1, 3, 224, 224) -> (128, 128) linear float32 num_threads=2 1.5X 0.3ms vs 0.2ms
(1, 3, 224, 224) -> (128, 128) nearest float32 num_threads=2 1.8X 0.163ms vs 0.093ms
(1, 3, 224, 224) -> (128, 128) nearest uint8 num_threads=2 1.5X 0.2ms vs 0.1ms
(1, 3, 224, 224) -> (128, 128) nearest-exact float32 num_threads=2 1.9X 0.180ms vs 0.096ms
(1, 3, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=2 1.6X 0.2ms vs 0.1ms
(1, 1, 224, 224) -> (128, 128) linear float32 num_threads=2 6X 0.264ms vs 0.044ms
(1, 1, 224, 224) -> (128, 128) nearest float32 num_threads=2 10X 0.278ms vs 0.028ms
(1, 1, 224, 224) -> (128, 128) nearest uint8 num_threads=2 7X 0.270ms vs 0.037ms
(1, 1, 224, 224) -> (128, 128) nearest-exact float32 num_threads=2 11X 0.298ms vs 0.028ms
(1, 1, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=2 8X 0.293ms vs 0.037ms
(1, 3, 224, 224) -> (128, 128) linear float32 num_threads=12 1.5X 0.3ms vs 0.2ms
(1, 3, 224, 224) -> (128, 128) nearest float32 num_threads=12 1.7X 0.158ms vs 0.095ms
(1, 3, 224, 224) -> (128, 128) nearest uint8 num_threads=12 1.5X 0.2ms vs 0.1ms
(1, 3, 224, 224) -> (128, 128) nearest-exact float32 num_threads=12 1.7X 0.170ms vs 0.100ms
(1, 3, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=12 1.6X 0.2ms vs 0.1ms
(1, 1, 224, 224) -> (128, 128) linear float32 num_threads=12 6X 0.269ms vs 0.043ms
(1, 1, 224, 224) -> (128, 128) nearest float32 num_threads=12 11X 0.291ms vs 0.027ms
(1, 1, 224, 224) -> (128, 128) nearest uint8 num_threads=12 8X 0.281ms vs 0.037ms
(1, 1, 224, 224) -> (128, 128) nearest-exact float32 num_threads=12 11X 0.305ms vs 0.028ms
(1, 1, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=12 8X 0.306ms vs 0.038ms
(1, 3, 224, 224) -> (128, 128) linear float32 num_threads=32 1.5X 0.3ms vs 0.2ms
(1, 3, 224, 224) -> (128, 128) nearest float32 num_threads=32 1.6X 0.160ms vs 0.098ms
(1, 3, 224, 224) -> (128, 128) nearest uint8 num_threads=32 1.5X 0.2ms vs 0.1ms
(1, 3, 224, 224) -> (128, 128) nearest-exact float32 num_threads=32 1.7X 0.171ms vs 0.099ms
(1, 3, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=32 1.6X 0.2ms vs 0.1ms
(1, 1, 224, 224) -> (128, 128) linear float32 num_threads=32 6X 0.269ms vs 0.044ms
(1, 1, 224, 224) -> (128, 128) nearest float32 num_threads=32 10X 0.282ms vs 0.028ms
(1, 1, 224, 224) -> (128, 128) nearest uint8 num_threads=32 7X 0.276ms vs 0.037ms
(1, 1, 224, 224) -> (128, 128) nearest-exact float32 num_threads=32 11X 0.305ms vs 0.028ms
(1, 1, 224, 224) -> (128, 128) nearest-exact uint8 num_threads=32 8X 0.299ms vs 0.038ms
----------------------------------------------------------------------------------------------------
(1, 3, 320, 320) -> (256, 256) linear float32 num_threads=1 1.0X 1.2ms vs 1.3ms
(1, 3, 320, 320) -> (256, 256) nearest float32 num_threads=1 2.0X 1.2ms vs 0.6ms
(1, 3, 320, 320) -> (256, 256) nearest uint8 num_threads=1 1.7X 1.1ms vs 0.7ms
(1, 3, 320, 320) -> (256, 256) nearest-exact float32 num_threads=1 2.1X 1.2ms vs 0.6ms
(1, 3, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=1 1.9X 1.2ms vs 0.7ms
(1, 1, 320, 320) -> (256, 256) linear float32 num_threads=1 8X 1.1ms vs 0.1ms
(1, 1, 320, 320) -> (256, 256) nearest float32 num_threads=1 15X 1.109ms vs 0.073ms
(1, 1, 320, 320) -> (256, 256) nearest uint8 num_threads=1 10X 1.1ms vs 0.1ms
(1, 1, 320, 320) -> (256, 256) nearest-exact float32 num_threads=1 16X 1.192ms vs 0.074ms
(1, 1, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=1 11X 1.2ms vs 0.1ms
(1, 3, 320, 320) -> (256, 256) linear float32 num_threads=2 1.7X 1.2ms vs 0.7ms
(1, 3, 320, 320) -> (256, 256) nearest float32 num_threads=2 2.0X 0.6ms vs 0.3ms
(1, 3, 320, 320) -> (256, 256) nearest uint8 num_threads=2 1.7X 0.6ms vs 0.3ms
(1, 3, 320, 320) -> (256, 256) nearest-exact float32 num_threads=2 2.2X 0.7ms vs 0.3ms
(1, 3, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=2 1.8X 0.6ms vs 0.3ms
(1, 1, 320, 320) -> (256, 256) linear float32 num_threads=2 9X 1.0ms vs 0.1ms
(1, 1, 320, 320) -> (256, 256) nearest float32 num_threads=2 11X 0.598ms vs 0.052ms
(1, 1, 320, 320) -> (256, 256) nearest uint8 num_threads=2 8X 0.556ms vs 0.072ms
(1, 1, 320, 320) -> (256, 256) nearest-exact float32 num_threads=2 12X 0.649ms vs 0.053ms
(1, 1, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=2 8X 0.598ms vs 0.073ms
(1, 3, 320, 320) -> (256, 256) linear float32 num_threads=12 5X 1.2ms vs 0.3ms
(1, 3, 320, 320) -> (256, 256) nearest float32 num_threads=12 1.5X 0.2ms vs 0.1ms
(1, 3, 320, 320) -> (256, 256) nearest uint8 num_threads=12 1.3X 0.2ms vs 0.1ms
(1, 3, 320, 320) -> (256, 256) nearest-exact float32 num_threads=12 1.6X 0.2ms vs 0.1ms
(1, 3, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=12 1.4X 0.2ms vs 0.1ms
(1, 1, 320, 320) -> (256, 256) linear float32 num_threads=12 9X 1.0ms vs 0.1ms
(1, 1, 320, 320) -> (256, 256) nearest float32 num_threads=12 12X 0.572ms vs 0.048ms
(1, 1, 320, 320) -> (256, 256) nearest uint8 num_threads=12 8X 0.560ms vs 0.068ms
(1, 1, 320, 320) -> (256, 256) nearest-exact float32 num_threads=12 13X 0.617ms vs 0.049ms
(1, 1, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=12 9X 0.604ms vs 0.068ms
(1, 3, 320, 320) -> (256, 256) linear float32 num_threads=32 5X 1.2ms vs 0.3ms
(1, 3, 320, 320) -> (256, 256) nearest float32 num_threads=32 1.5X 0.2ms vs 0.1ms
(1, 3, 320, 320) -> (256, 256) nearest uint8 num_threads=32 1.4X 0.2ms vs 0.1ms
(1, 3, 320, 320) -> (256, 256) nearest-exact float32 num_threads=32 1.6X 0.2ms vs 0.1ms
(1, 3, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=32 1.4X 0.2ms vs 0.1ms
(1, 1, 320, 320) -> (256, 256) linear float32 num_threads=32 13X 1.042ms vs 0.081ms
(1, 1, 320, 320) -> (256, 256) nearest float32 num_threads=32 12X 0.586ms vs 0.050ms
(1, 1, 320, 320) -> (256, 256) nearest uint8 num_threads=32 8X 0.562ms vs 0.069ms
(1, 1, 320, 320) -> (256, 256) nearest-exact float32 num_threads=32 12X 0.621ms vs 0.051ms
(1, 1, 320, 320) -> (256, 256) nearest-exact uint8 num_threads=32 9X 0.609ms vs 0.070ms
----------------------------------------------------------------------------------------------------
(1, 3, 600, 400) -> (224, 224) linear float32 num_threads=1 1.0X 1.0ms vs 1.0ms
(1, 3, 600, 400) -> (224, 224) nearest float32 num_threads=1 1.9X 0.9ms vs 0.5ms
(1, 3, 600, 400) -> (224, 224) nearest uint8 num_threads=1 1.7X 0.9ms vs 0.5ms
(1, 3, 600, 400) -> (224, 224) nearest-exact float32 num_threads=1 2.1X 1.0ms vs 0.5ms
(1, 3, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=1 1.8X 0.9ms vs 0.5ms
(1, 1, 600, 400) -> (224, 224) linear float32 num_threads=1 7X 0.8ms vs 0.1ms
(1, 1, 600, 400) -> (224, 224) nearest float32 num_threads=1 14X 0.852ms vs 0.061ms
(1, 1, 600, 400) -> (224, 224) nearest uint8 num_threads=1 9X 0.828ms vs 0.087ms
(1, 1, 600, 400) -> (224, 224) nearest-exact float32 num_threads=1 15X 0.922ms vs 0.061ms
(1, 1, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=1 10X 0.897ms vs 0.087ms
(1, 3, 600, 400) -> (224, 224) linear float32 num_threads=2 1.6X 0.9ms vs 0.6ms
(1, 3, 600, 400) -> (224, 224) nearest float32 num_threads=2 1.9X 0.5ms vs 0.2ms
(1, 3, 600, 400) -> (224, 224) nearest uint8 num_threads=2 1.7X 0.4ms vs 0.3ms
(1, 3, 600, 400) -> (224, 224) nearest-exact float32 num_threads=2 2.1X 0.5ms vs 0.3ms
(1, 3, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=2 1.8X 0.5ms vs 0.3ms
(1, 1, 600, 400) -> (224, 224) linear float32 num_threads=2 10X 0.808ms vs 0.084ms
(1, 1, 600, 400) -> (224, 224) nearest float32 num_threads=2 10X 0.462ms vs 0.046ms
(1, 1, 600, 400) -> (224, 224) nearest uint8 num_threads=2 7X 0.429ms vs 0.062ms
(1, 1, 600, 400) -> (224, 224) nearest-exact float32 num_threads=2 12X 0.504ms vs 0.044ms
(1, 1, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=2 7X 0.461ms vs 0.063ms
(1, 3, 600, 400) -> (224, 224) linear float32 num_threads=12 4X 1.0ms vs 0.2ms
(1, 3, 600, 400) -> (224, 224) nearest float32 num_threads=12 1.7X 0.2ms vs 0.1ms
(1, 3, 600, 400) -> (224, 224) nearest uint8 num_threads=12 1.5X 0.2ms vs 0.1ms
(1, 3, 600, 400) -> (224, 224) nearest-exact float32 num_threads=12 1.9X 0.2ms vs 0.1ms
(1, 3, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=12 1.6X 0.2ms vs 0.1ms
(1, 1, 600, 400) -> (224, 224) linear float32 num_threads=12 12X 0.820ms vs 0.067ms
(1, 1, 600, 400) -> (224, 224) nearest float32 num_threads=12 11X 0.438ms vs 0.041ms
(1, 1, 600, 400) -> (224, 224) nearest uint8 num_threads=12 8X 0.431ms vs 0.056ms
(1, 1, 600, 400) -> (224, 224) nearest-exact float32 num_threads=12 12X 0.482ms vs 0.041ms
(1, 1, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=12 8X 0.467ms vs 0.056ms
(1, 3, 600, 400) -> (224, 224) linear float32 num_threads=32 4X 1.0ms vs 0.3ms
(1, 3, 600, 400) -> (224, 224) nearest float32 num_threads=32 1.7X 0.2ms vs 0.1ms
(1, 3, 600, 400) -> (224, 224) nearest uint8 num_threads=32 1.5X 0.2ms vs 0.1ms
(1, 3, 600, 400) -> (224, 224) nearest-exact float32 num_threads=32 1.8X 0.2ms vs 0.1ms
(1, 3, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=32 1.6X 0.2ms vs 0.1ms
(1, 1, 600, 400) -> (224, 224) linear float32 num_threads=32 12X 0.824ms vs 0.070ms
(1, 1, 600, 400) -> (224, 224) nearest float32 num_threads=32 10X 0.443ms vs 0.044ms
(1, 1, 600, 400) -> (224, 224) nearest uint8 num_threads=32 7X 0.438ms vs 0.059ms
(1, 1, 600, 400) -> (224, 224) nearest-exact float32 num_threads=32 11X 0.479ms vs 0.045ms
(1, 1, 600, 400) -> (224, 224) nearest-exact uint8 num_threads=32 8X 0.470ms vs 0.059ms
----------------------------------------------------------------------------------------------------
(1, 3, 800, 800) -> (500, 500) linear float32 num_threads=1 1.0X 4.7ms vs 4.7ms
(1, 3, 800, 800) -> (500, 500) nearest float32 num_threads=1 2.0X 4.4ms vs 2.2ms
(1, 3, 800, 800) -> (500, 500) nearest uint8 num_threads=1 1.8X 4.3ms vs 2.5ms
(1, 3, 800, 800) -> (500, 500) nearest-exact float32 num_threads=1 2.1X 4.7ms vs 2.2ms
(1, 3, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=1 1.9X 4.6ms vs 2.5ms
(1, 1, 800, 800) -> (500, 500) linear float32 num_threads=1 9X 4.0ms vs 0.4ms
(1, 1, 800, 800) -> (500, 500) nearest float32 num_threads=1 17X 4.2ms vs 0.2ms
(1, 1, 800, 800) -> (500, 500) nearest uint8 num_threads=1 11X 4.1ms vs 0.4ms
(1, 1, 800, 800) -> (500, 500) nearest-exact float32 num_threads=1 19X 4.6ms vs 0.2ms
(1, 1, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=1 12X 4.5ms vs 0.4ms
(1, 3, 800, 800) -> (500, 500) linear float32 num_threads=2 1.7X 4.7ms vs 2.7ms
(1, 3, 800, 800) -> (500, 500) nearest float32 num_threads=2 2.1X 2.4ms vs 1.1ms
(1, 3, 800, 800) -> (500, 500) nearest uint8 num_threads=2 1.8X 2.2ms vs 1.3ms
(1, 3, 800, 800) -> (500, 500) nearest-exact float32 num_threads=2 2.3X 2.6ms vs 1.1ms
(1, 3, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=2 1.9X 2.3ms vs 1.3ms
(1, 1, 800, 800) -> (500, 500) linear float32 num_threads=2 15X 4.0ms vs 0.3ms
(1, 1, 800, 800) -> (500, 500) nearest float32 num_threads=2 16X 2.3ms vs 0.1ms
(1, 1, 800, 800) -> (500, 500) nearest uint8 num_threads=2 9X 2.1ms vs 0.2ms
(1, 1, 800, 800) -> (500, 500) nearest-exact float32 num_threads=2 17X 2.5ms vs 0.1ms
(1, 1, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=2 10X 2.3ms vs 0.2ms
(1, 3, 800, 800) -> (500, 500) linear float32 num_threads=12 10X 4.7ms vs 0.5ms
(1, 3, 800, 800) -> (500, 500) nearest float32 num_threads=12 1.9X 0.4ms vs 0.2ms
(1, 3, 800, 800) -> (500, 500) nearest uint8 num_threads=12 1.7X 0.4ms vs 0.2ms
(1, 3, 800, 800) -> (500, 500) nearest-exact float32 num_threads=12 1.9X 0.4ms vs 0.2ms
(1, 3, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=12 1.8X 0.4ms vs 0.2ms
(1, 1, 800, 800) -> (500, 500) linear float32 num_threads=12 41X 3.969ms vs 0.096ms
(1, 1, 800, 800) -> (500, 500) nearest float32 num_threads=12 11X 0.545ms vs 0.051ms
(1, 1, 800, 800) -> (500, 500) nearest uint8 num_threads=12 8X 0.532ms vs 0.070ms
(1, 1, 800, 800) -> (500, 500) nearest-exact float32 num_threads=12 11X 0.590ms vs 0.052ms
(1, 1, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=12 8X 0.578ms vs 0.071ms
(1, 3, 800, 800) -> (500, 500) linear float32 num_threads=32 17X 4.7ms vs 0.3ms
(1, 3, 800, 800) -> (500, 500) nearest float32 num_threads=32 1.8X 0.2ms vs 0.1ms
(1, 3, 800, 800) -> (500, 500) nearest uint8 num_threads=32 2.0X 0.3ms vs 0.1ms
(1, 3, 800, 800) -> (500, 500) nearest-exact float32 num_threads=32 1.9X 0.2ms vs 0.1ms
(1, 3, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=32 1.6X 0.2ms vs 0.1ms
(1, 1, 800, 800) -> (500, 500) linear float32 num_threads=32 45X 4.028ms vs 0.090ms
(1, 1, 800, 800) -> (500, 500) nearest float32 num_threads=32 10X 0.549ms vs 0.053ms
(1, 1, 800, 800) -> (500, 500) nearest uint8 num_threads=32 7X 0.536ms vs 0.072ms
(1, 1, 800, 800) -> (500, 500) nearest-exact float32 num_threads=32 11X 0.592ms vs 0.055ms
(1, 1, 800, 800) -> (500, 500) nearest-exact uint8 num_threads=32 8X 0.581ms vs 0.074ms
```
</details>
Code:
<details>
I used this file which is adapted from https://github.com/pytorch/pytorch/blob/master/benchmarks/operator_benchmark/pt/interpolate_test.py
```py
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for interpolate operator."""
class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
def init(self, input_size, output_size, channels_last=False, mode='linear', dtype=torch.float):
input_image = torch.randint(0, 256, size=input_size, dtype=dtype, device='cpu',
requires_grad=self.auto_set())
if channels_last:
if input_image.ndim == 4:
input_image = input_image.contiguous(memory_format=torch.channels_last)
elif input_image.ndim == 5:
input_image = input_image.contiguous(memory_format=torch.channels_last_3d)
else:
raise ValueError(
f"Can not set channels_last to the input of {input_image.ndim} dims"
)
align_corners = None if "nearest" in mode else False
if mode == "linear":
mode = {
3: 'linear',
4: 'bilinear',
5: 'trilinear',
}[input_image.ndim]
self.inputs = {
"input_image": input_image,
"output_size": output_size,
"mode": mode,
"align_corners": align_corners,
}
self.set_module_name("interpolate")
def forward(self, input_image, output_size, mode, align_corners):
return torch.nn.functional.interpolate(input_image, size=output_size, mode=mode,
align_corners=align_corners)
def make_config():
sizes = (
((224, 224), (64, 64)),
((224, 224), (128, 128)),
((600, 400), (224, 224)),
((320, 320), (256, 256)),
((800, 800), (500, 500)),
)
attrs = []
for (HW1, HW2) in sizes:
attrs.append([(1, 3, *HW1), HW2]) # 3 channels
attrs.append([(1, 1, *HW1), HW2]) # 1 channel
attrs.append([(1, 3, *HW2), HW1]) # 3 channels
attrs.append([(1, 1, *HW2), HW1]) # 1 channel
config = op_bench.config_list(
attr_names=["input_size", "output_size"],
attrs=attrs,
cross_product_configs={
'channels_last': [True],
'mode': ["linear", "nearest", "nearest-exact"],
'dtype': [torch.float, torch.uint8]
},
tags=["short"],
)
# Need to remove instances with both torch.int and linear
# Note: this is naaaasty
def get_mode(l):
for d in l:
if "mode" in d:
return d["mode"]
def get_dtype(l):
for d in l:
if "dtype" in d:
return d["dtype"]
config = [l for l in config if not(get_mode(l) == "linear" and get_dtype(l) == torch.uint8)]
return config
config = make_config()
op_bench.generate_pt_test(config, InterpolateBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()
```
with
```
for num_threads in 1 2 12 32; do echo "num_threads=$num_threads" && python -m pt.my_interpolate_test --iterations 1000 --omp_num_threads $num_threads ; done > $out_file
```
and this very ugly helper
```py
import re
with open("main") as f:
main = f.readlines()
with open("new") as f:
new = f.readlines()
out = []
for main_line, new_line in zip(main, new):
if main_line.startswith("num_threads="):
num_threads = int(main_line.split("=")[-1])
if main_line.startswith("# Input"):
deets = f"{main_line.strip()}, {num_threads=}"
if main_line.startswith("Forward"):
main_time = float(main_line.split()[-1])
new_time = float(new_line.split()[-1])
ratio = main_time / new_time
fmt = ".1f" if ratio < 3 else ".0f"
improv = f"{ratio:{fmt}}X"
time_fmt = ",.3f" if new_time < 100 else ",.1f"
deets = deets.strip().replace("# Input: ", "")
deets = deets.replace(": ", "=")
deets = deets.replace("input_size=", "")
deets = deets.replace(", output_size=", " -> ")
deets = deets.replace("dtype=torch.", "")
deets = deets.replace("mode=", "")
deets = deets.replace("channels_last=True, ", "")
split = deets.split(",")
size = ','.join(split[:-3])
mode, dtype, threads = split[-3:]
deets = f"{size:<30} {mode:<15} {dtype:<10} {threads:<15}"
l = f"{deets} {improv:<5} {main_time / 1000:{time_fmt}}ms vs {new_time / 1000:{time_fmt}}ms"
out.append(l)
def key(s):
# s = ''.join(s.split()[1:]) # remove "N.nX" part
num_threads = (int(re.findall(r"num_threads=(\d+)", s)[0]),)
input_shape, output_shape = re.findall("\(.*?\)", s)
input_shape = input_shape[1:-1] # remove parenthesis
input_HW = tuple(int(x) for x in input_shape.split(",")[-2:])
input_C = (-int(input_shape.split(",")[1]),)
output_HW = tuple(int(x) for x in output_shape[1:-1].split(","))
is_downsample = (output_HW[0] < input_HW[0],)
if "linear" in s:
mode = "linear"
elif "nearest-exact" in s:
mode = "nearest-exact"
else:
assert "nearest" in s
mode = "nearest"
mode = (mode,)
return is_downsample + input_HW + output_HW + num_threads + input_C + mode
for i, l in enumerate(sorted(out, key=key)):
if i % 10 == 0 and i % 40 != 0:
print()
if i % 40 == 0:
print("-" * 100)
print(l)
```
</details>
Closes https://github.com/pytorch/pytorch/issues/83840
When this is merged we should be able to remove some hack in vision as well https://github.com/pytorch/vision/pull/6661 (CC @vfdev-5 @datumbox )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86361
Approved by: https://github.com/vfdev-5, https://github.com/datumbox, https://github.com/fmassa