Added antialias flag to interpolate (CPU only, bilinear) (#65142)
Summary:
Description:
- Added antialias flag to interpolate (CPU only)
- forward and backward for bilinear mode
- added tests
### Benchmarks
<details>
<summary>
Forward pass, CPU. PTH interpolation vs PIL
</summary>
Cases:
- PTH RGB 3 Channels, float32 vs PIL RGB uint8 (apply vs pears)
- PTH 1 Channel, float32 vs PIL 1 Channel Float
Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112
```
# OMP_NUM_THREADS=1 python bench_interp_aa_vs_pillow.py
Torch config: PyTorch built with:
- GCC 9.3
- C++ Version: 201402
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- CPU capability usage: AVX2
- CUDA Runtime 11.1
- NVCC architecture flags: -gencode;arch=compute_75,code=sm_75
- CuDNN 8.0.5
- Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON,
Num threads: 1
[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (320, 196) ------------------------]
| Reference, PIL 8.3.2, mode: RGB | 1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
channels_first contiguous torch.float32 | 2.9 | 3.1
channels_last non-contiguous torch.float32 | 2.6 | 3.6
Times are in milliseconds (ms).
[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (460, 220) ------------------------]
| Reference, PIL 8.3.2, mode: RGB | 1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
channels_first contiguous torch.float32 | 3.4 | 4.0
channels_last non-contiguous torch.float32 | 3.4 | 4.8
Times are in milliseconds (ms).
[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (120, 96) -------------------------]
| Reference, PIL 8.3.2, mode: RGB | 1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
channels_first contiguous torch.float32 | 1.6 | 1.8
channels_last non-contiguous torch.float32 | 1.6 | 1.9
Times are in milliseconds (ms).
[----------------------- Downsampling: torch.Size([1, 3, 906, 438]) -> (1200, 196) ------------------------]
| Reference, PIL 8.3.2, mode: RGB | 1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
channels_first contiguous torch.float32 | 9.0 | 11.3
channels_last non-contiguous torch.float32 | 8.9 | 12.5
Times are in milliseconds (ms).
[----------------------- Downsampling: torch.Size([1, 3, 906, 438]) -> (120, 1200) ------------------------]
| Reference, PIL 8.3.2, mode: RGB | 1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
channels_first contiguous torch.float32 | 2.1 | 1.8
channels_last non-contiguous torch.float32 | 2.1 | 3.4
Times are in milliseconds (ms).
[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (320, 196) --------------]
| Reference, PIL 8.3.2, mode: F | 1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
contiguous torch.float32 | 1.2 | 1.0
Times are in milliseconds (ms).
[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (460, 220) --------------]
| Reference, PIL 8.3.2, mode: F | 1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
contiguous torch.float32 | 1.4 | 1.3
Times are in milliseconds (ms).
[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (120, 96) ---------------]
| Reference, PIL 8.3.2, mode: F | 1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
contiguous torch.float32 | 719.9 | 599.9
Times are in microseconds (us).
[-------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (1200, 196) --------------]
| Reference, PIL 8.3.2, mode: F | 1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
contiguous torch.float32 | 3.7 | 3.5
Times are in milliseconds (ms).
[-------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (120, 1200) --------------]
| Reference, PIL 8.3.2, mode: F | 1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
contiguous torch.float32 | 834.4 | 605.7
Times are in microseconds (us).
```
</details>
Code is moved from torchvision: https://github.com/pytorch/vision/pull/4208
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65142
Reviewed By: mrshenli
Differential Revision: D32432405
Pulled By: jbschlosser
fbshipit-source-id: b66c548347f257c522c36105868532e8bc1d4c6d