7d78a6fc - Update interpolate to use new upsample overloads (#43025)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43025

- Use new overloads that better reflect the arguments to `interpolate`.
- A more uniform interface for the upsample ops allows simplifying the Python code.
- Also reorder the overloads in native_functions.yaml to give them priority.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/37177
ghstack-source-id: 106938111

Test Plan: test_nn has pretty good coverage. Relying on CI for ONNX, etc. Didn't test FC (forward compatibility) because this change is *not* forward compatible.

To ensure backwards compatibility, I ran this code before this change:

```python
import torch

def test_func(arg):
    interp = torch.nn.functional.interpolate
    with_size = interp(arg, size=(16, 16))
    with_scale = interp(arg, scale_factor=[2.1, 2.2], recompute_scale_factor=False)
    with_compute = interp(arg, scale_factor=[2.1, 2.2])
    return (with_size, with_scale, with_compute)

traced_func = torch.jit.trace(test_func, torch.randn(1, 1, 1, 1))
sample = torch.randn(1, 3, 7, 7)
output = traced_func(sample)
assert not torch.allclose(output[1], output[2])
torch.jit.save(traced_func, "model.pt")
torch.save((sample, output), "data.pt")
```

then this code after this change:

```python
import torch

model = torch.jit.load("model.pt")
sample, golden = torch.load("data.pt")
result = model(sample)
for r, g in zip(result, golden):
    assert torch.allclose(r, g)
```

Reviewed By: AshkanAliabadi

Differential Revision: D21209991

fbshipit-source-id: 5b2ebb7c3ed76947361fe532d1dbdd6faa3544c8
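As background for the overload distinction the test plan exercises, here is a small sketch (not part of the original commit) of how `interpolate` behaves with a `size` argument versus a `scale_factor` argument, with and without `recompute_scale_factor`. The shapes shown assume PyTorch's floor-based output-size computation (`floor(input_size * scale_factor)`):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

# size-based path: output spatial dims are given directly
y_size = F.interpolate(x, size=(16, 16))
assert y_size.shape == (1, 3, 16, 16)

# scale-factor path: output dims are floor(8 * 2.1) = 16 and floor(8 * 2.2) = 17.
# With recompute_scale_factor=False the given floats are passed to the kernel
# as-is; with the default they are recomputed from the rounded output size,
# so the two calls produce the same shape but can differ in values.
y_scale = F.interpolate(x, scale_factor=[2.1, 2.2], recompute_scale_factor=False)
y_recomp = F.interpolate(x, scale_factor=[2.1, 2.2])
assert y_scale.shape == y_recomp.shape == (1, 3, 16, 17)
```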