pytorch
7ec8a4d2 - Vectorized horizontal flip implementation (#88989)

Vectorized horizontal flip implementation (#88989)

When we benchmarked image processing transforms in torchvision (tensor vs Pillow), we saw that horizontal flip on uint8 data `(3, X, X)` is 2-3x slower. Because the output's first stride is negative, the implementation falls back to a simple data copy using [`basic_loop`](https://github.com/pytorch/pytorch/blob/8371bb8a3dddbead709bc1e9d26715818a34fa8a/aten/src/ATen/native/cpu/Loops.h#L286).

This PR adds a vectorized path to the horizontal flip op for the dtypes uint8, int, float32, long and double. The resulting speed-up narrows the gap between PIL and tensor ops:

```
CPU capability usage: AVX2
[----------------------------------------------------- Horizontal flip -----------------------------------------------------]
                                            |  torch (1.14.0a0+git2ed1d29) PR  |  Pillow (9.3.0)     |  torch (1.14.0.dev20221116+cu116) nightly
1 threads: -------------------------------------------------------------------------------------------------------------------
  channels=3, size=256, dtype=torch.int64   |    101.307 (+-0.904)             |                     |    111.364 (+-0.328)
  channels=3, size=520, dtype=torch.int64   |    462.369 (+-2.184)             |                     |    505.602 (+-0.541)
  channels=3, size=712, dtype=torch.int64   |   1855.441 (+-6.528)             |                     |   1828.370 (+-8.600)
  channels=1, size=256, dtype=torch.int32   |     22.282 (+-0.130)             |    44.218 (+-0.936) |     34.651 (+-0.162)
  channels=1, size=520, dtype=torch.int32   |     72.180 (+-0.076)             |   166.639 (+-1.180) |    118.820 (+-0.210)
  channels=1, size=712, dtype=torch.int32   |    129.621 (+-0.649)             |   307.140 (+-2.221) |    216.104 (+-0.793)
  channels=3, size=256, dtype=torch.uint8   |     51.685 (+-0.200)             |    44.171 (+-0.818) |    361.611 (+-0.276)
  channels=3, size=520, dtype=torch.uint8   |    223.320 (+-0.726)             |   166.607 (+-2.256) |   1462.012 (+-4.917)
  channels=3, size=712, dtype=torch.uint8   |    423.298 (+-1.156)             |   307.067 (+-1.999) |   2738.481 (+-1.715)
  channels=1, size=256, dtype=torch.float32 |     22.281 (+-0.056)             |    44.149 (+-0.808) |     35.316 (+-0.028)
  channels=1, size=520, dtype=torch.float32 |     72.268 (+-0.106)             |   166.631 (+-1.212) |    119.504 (+-0.340)
  channels=1, size=712, dtype=torch.float32 |    129.777 (+-0.632)             |   307.078 (+-1.909) |    216.987 (+-0.185)
  channels=1, size=256, dtype=torch.float16 |     32.789 (+-0.081)             |                     |     34.044 (+-0.039)
  channels=1, size=520, dtype=torch.float16 |    112.693 (+-0.478)             |                     |    117.445 (+-0.125)
  channels=1, size=712, dtype=torch.float16 |    203.644 (+-0.791)             |                     |    213.283 (+-0.397)
  channels=3, size=256, dtype=torch.float64 |    102.058 (+-0.333)             |                     |    108.404 (+-0.346)
  channels=3, size=520, dtype=torch.float64 |    473.139 (+-1.327)             |                     |    503.265 (+-0.365)
  channels=3, size=712, dtype=torch.float64 |   1854.489 (+-9.513)             |                     |   1844.345 (+-1.371)
  channels=1, size=256, dtype=torch.int16   |     11.927 (+-0.056)             |                     |     33.993 (+-0.037)
  channels=1, size=520, dtype=torch.int16   |     39.724 (+-0.148)             |                     |    117.577 (+-0.153)
  channels=1, size=712, dtype=torch.int16   |     68.264 (+-0.133)             |                     |    213.118 (+-0.157)

Times are in microseconds (us).
```

```
CPU capability usage: AVX512
[----------------------------------------------------- Horizontal flip -----------------------------------------------------]
                                            |  torch (1.14.0a0+git2ed1d29) PR  |  Pillow (9.3.0)     |  torch (1.14.0.dev20221118+cu116) nightly
1 threads: -------------------------------------------------------------------------------------------------------------------
  channels=3, size=256, dtype=torch.int64   |    131.244 (+-1.954)             |                     |    135.649 (+-4.066)
  channels=3, size=520, dtype=torch.int64   |    522.032 (+-4.660)             |                     |    539.822 (+-10.420)
  channels=3, size=712, dtype=torch.int64   |   1041.111 (+-53.575)            |                     |   1322.411 (+-80.017)
  channels=1, size=256, dtype=torch.int32   |     10.108 (+-0.414)             |    49.164 (+-1.000) |     34.606 (+-0.865)
  channels=1, size=520, dtype=torch.int32   |     93.218 (+-1.417)             |   191.985 (+-5.047) |    133.664 (+-5.372)
  channels=1, size=712, dtype=torch.int32   |    167.919 (+-2.854)             |   353.574 (+-6.568) |    246.162 (+-5.753)
  channels=3, size=256, dtype=torch.uint8   |     34.710 (+-0.541)             |    49.005 (+-0.923) |    136.603 (+-2.339)
  channels=3, size=520, dtype=torch.uint8   |    154.873 (+-3.049)             |   191.729 (+-4.997) |    534.329 (+-10.754)
  channels=3, size=712, dtype=torch.uint8   |    290.319 (+-4.819)             |   351.619 (+-6.978) |    997.119 (+-33.086)
  channels=1, size=256, dtype=torch.float32 |     10.345 (+-0.338)             |    49.105 (+-0.942) |     35.478 (+-0.733)
  channels=1, size=520, dtype=torch.float32 |     81.131 (+-5.281)             |   191.697 (+-4.555) |    133.554 (+-4.193)
  channels=1, size=712, dtype=torch.float32 |    169.581 (+-3.476)             |   352.995 (+-10.792)|    251.089 (+-7.485)
  channels=1, size=256, dtype=torch.float16 |     35.259 (+-0.612)             |                     |     35.154 (+-0.924)
  channels=1, size=520, dtype=torch.float16 |    132.407 (+-1.980)             |                     |    131.850 (+-5.611)
  channels=1, size=712, dtype=torch.float16 |    240.192 (+-5.479)             |                     |    239.555 (+-7.273)
  channels=3, size=256, dtype=torch.float64 |    129.649 (+-2.349)             |                     |    130.429 (+-6.240)
  channels=3, size=520, dtype=torch.float64 |    548.534 (+-5.179)             |                     |    622.568 (+-25.720)
  channels=3, size=712, dtype=torch.float64 |   1208.091 (+-77.095)            |                     |   1679.204 (+-316.292)
  channels=1, size=256, dtype=torch.int16   |      7.801 (+-0.115)             |                     |     34.517 (+-0.482)
  channels=1, size=520, dtype=torch.int16   |     36.010 (+-0.855)             |                     |    131.001 (+-1.686)
  channels=1, size=712, dtype=torch.int16   |     87.395 (+-1.355)             |                     |    237.731 (+-4.181)

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/c0421f54c8aed655b042dd1ce4cb621e)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88989

Approved by: https://github.com/lezcano, https://github.com/datumbox, https://github.com/peterbell10, https://github.com/ngimel
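For context, the operation being benchmarked is `torch.flip` along the last (width) dimension of a CHW image tensor. The sketch below is not the gist's actual benchmark script; it is a minimal, hedged reproduction of one measurement, assuming a recent PyTorch build (the tensor shape `(3, 520, 520)` mirrors one row of the tables):

```python
import torch
from torch.utils import benchmark

# Horizontal flip of a small CHW uint8 tensor. torch.flip materializes a
# contiguous copy; internally the output is traversed with a negative
# stride, which is the code path this PR vectorizes.
img = torch.arange(12, dtype=torch.uint8).reshape(1, 3, 4)
flipped = torch.flip(img, dims=(-1,))
print(flipped[0, 0].tolist())  # first row [0, 1, 2, 3] becomes [3, 2, 1, 0]

# Rough single-case timing, loosely mirroring one table row
# (channels=3, size=520, dtype=torch.uint8); exact numbers will differ by machine.
t = torch.randint(0, 256, (3, 520, 520), dtype=torch.uint8)
timer = benchmark.Timer(
    stmt="torch.flip(t, dims=(-1,))",
    globals={"torch": torch, "t": t},
    num_threads=1,
)
print(timer.timeit(100))
```

Note that, unlike NumPy's `arr[..., ::-1]`, `torch.flip` always returns a fresh contiguous tensor rather than a negative-stride view, which is why the copy loop's performance matters here.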