Add align_corners option to grid_sample and affine_grid, change default to False (#24929)
Summary:
Resolves: https://github.com/pytorch/pytorch/issues/20785
Addresses https://github.com/pytorch/pytorch/issues/24470 for `affine_grid`
Subsumes and closes: https://github.com/pytorch/pytorch/pull/24878 and likewise closes: https://github.com/pytorch/pytorch/issues/24821
Adds the `align_corners` option to `grid_sample` and `affine_grid`, paralleling the option that was added to `interpolate` in version 0.4.0.
In short, setting `align_corners` to `False` allows these functions to be resolution agnostic.
This ensures, for example, that a grid generated from a neural net trained to warp 1024x1024 images will also work to warp the same image upsampled/downsampled to other resolutions like 512x512 or 2048x2048 without producing scaling/stretching artifacts.
Refer to the documentation and https://github.com/pytorch/pytorch/issues/20785 for more details.
#### BC-Breaking Changes
- **Important**: BC-Breaking change because of new default for `align_corners`
The old functionality can still be achieved by setting `align_corners=True`, but the default is now set to `align_corners=False`, since this is the more correct setting, and since this matches the default setting of `interpolate`.
- **Should not cause BC issues**: BC-Breaking change for pathological use case
2D affine transforms on 1D coordinates and 3D affine transforms on 2D coordinates (that is, when one of the spatial dimensions has an empty span) are ill-defined, and not an intended use case of `affine_grid`. Whereas before, all grid point components along such dimension were set arbitrarily to `-1` (that is, before multiplying be the affine matrix), they are now all set instead to `0`, which is a much more consistent and defensible arbitrary choice. A warning is triggered for such cases.
#### Documentation
- Update `affine_grid` documentation to express that it does indeed support 3D affine transforms. This support was already there but not documented.
- Add documentation warnings for BC-breaking changes in `grid_sample` and `affine_grid` (see above).
#### Refactors
- `affine_grid` no longer dispatches to cuDNN under any circumstances.
The decision point for when the cuDNN `affine_grid_generator` is compatible with the native PyTorch version and when it fails is a headache to maintain (see [these conditions](https://github.com/seung-lab/pytorch/blob/5377478e94a7e2c86181a7686a7a492ecd147815/torch/nn/_functions/vision.py#L7-L8)). The native PyTorch kernel is now used in all cases.
- The kernels for `grid_sample` are slightly refactored to make maintenance easier.
#### Tests
Two new tests are added in `test_nn.py`:
- `test_affine_grid_error_checking` for errors and warnings in `affine_grid`
- `test_affine_grid_3D` for testing `affine_grid`'s 3D functionality. The functionality existed prior to this, but wasn't tested.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24929
Differential Revision: D16949064
Pulled By: ailzhang
fbshipit-source-id: b133ce0d47a2a5b3e2140b9d05fb05fca9140926