fix silent correctness bug with channels_last usage of upsample cuda kernels (#54744)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54744
Fixes https://github.com/pytorch/pytorch/issues/54590
After the upsample operators were ported to be structured, they now forward memory_format information to the output. This is a problem for the cuda kernels, which are not implemented to deal with the `torch.channels_last` memory format. The affected operators are:
* upsample_nearest2d
* upsample_bilinear2d
* upsample_nearest3d
* upsample_trilinear3d
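To illustrate why the forwarded memory format matters, here is a minimal pure-Python sketch (not the actual kernel code) of where one element lands under each layout. The helper names are hypothetical. The channels_last strides are the ones PyTorch reports for an (N, C, H, W) tensor stored physically as NHWC.

```python
def contiguous_strides(shape):
    """Row-major (NCHW) strides, in elements."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.insert(0, step)
        step *= size
    return tuple(strides)

def channels_last_strides(shape):
    """Strides of an (N, C, H, W) tensor whose memory is laid out as NHWC."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)

def offset(index, strides):
    """Position of a logical index inside the flat storage."""
    return sum(i * s for i, s in zip(index, strides))

shape = (2, 2, 3, 4)
index = (0, 1, 0, 0)  # n=0, c=1, h=0, w=0

# Where a kernel that assumes a contiguous output would write this element.
assumed = offset(index, contiguous_strides(shape))
# Where that element actually lives in a channels_last output.
actual = offset(index, channels_last_strides(shape))
print(assumed, actual)  # prints: 12 1
```

The two offsets disagree, so a kernel that ignores the output's strides fills the buffer without raising any error, but the values end up at the wrong logical positions.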
This fix just allocates a temporary, contiguous output tensor when the output comes in with a channels_last layout, writes the results to the temporary, and then copies the results back to the output tensor.
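The temporary-output pattern can be sketched in pure Python as follows. This is a toy simulation under stated assumptions, not the C++/CUDA code in this PR: tensors are flat lists plus strides, `nearest2x_kernel` and `upsample_with_fix` are hypothetical names, and the input is assumed to be contiguous for simplicity.

```python
from itertools import product

def contiguous_strides(shape):
    """Row-major (NCHW) strides, in elements."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.insert(0, step)
        step *= size
    return tuple(strides)

def offset(index, strides):
    return sum(i * s for i, s in zip(index, strides))

def all_indices(shape):
    return product(*(range(size) for size in shape))

def nearest2x_kernel(src, src_shape, dst):
    """Toy kernel: 2x nearest upsample that assumes src and dst are contiguous."""
    n, c, h, w = src_shape
    dst_shape = (n, c, 2 * h, 2 * w)
    src_strides = contiguous_strides(src_shape)
    dst_strides = contiguous_strides(dst_shape)
    for b, ch, y, x in all_indices(dst_shape):
        value = src[offset((b, ch, y // 2, x // 2), src_strides)]
        dst[offset((b, ch, y, x), dst_strides)] = value

def upsample_with_fix(src, src_shape, out, out_strides):
    """Run the kernel on a contiguous temporary when `out` is not contiguous."""
    n, c, h, w = src_shape
    out_shape = (n, c, 2 * h, 2 * w)
    dense = contiguous_strides(out_shape)
    if out_strides == dense:
        nearest2x_kernel(src, src_shape, out)
        return
    tmp = [0] * len(out)                   # temporary contiguous output
    nearest2x_kernel(src, src_shape, tmp)  # kernel writes where it expects to
    for idx in all_indices(out_shape):     # stride-aware copy back
        out[offset(idx, out_strides)] = tmp[offset(idx, dense)]

src_shape = (1, 2, 1, 2)
src = [0, 1, 2, 3]           # channel 0 holds 0, 1; channel 1 holds 2, 3
out_shape = (1, 2, 2, 4)
cl_strides = (16, 1, 8, 2)   # channels_last strides for out_shape

naive = [0] * 16
nearest2x_kernel(src, src_shape, naive)  # the bug: output layout is ignored
fixed = [0] * 16
upsample_with_fix(src, src_shape, fixed, cl_strides)

probe = (0, 1, 0, 0)         # first element of channel 1, expected value 2
print(naive[offset(probe, cl_strides)], fixed[offset(probe, cl_strides)])
# prints: 0 2
```

The extra allocation and copy cost some performance on channels_last inputs, which is the trade-off for getting correct results without rewriting the kernels.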
I held off on adding tests to get the fix out quickly, but I wrote a script and ran some manual tests that basically just assert that the outputs are the same for cpu and cuda, within some threshold. I ran it for all 4 operators:
```
import torch
def basically_equal(t1, t2):
    # Print whether every element of t1 and t2 differs by less than epsilon.
    epsilon = 1e-4
    diffs = torch.abs(t1 - t2)
    print(torch.all(diffs < epsilon))
# upsample 2d
a = torch.arange(48).reshape(2, 2, 3, 4).contiguous(memory_format=torch.channels_last).float()
out_cpu = torch.nn.functional.interpolate(a, scale_factor=2, mode='nearest')
out_cuda = torch.nn.functional.interpolate(a.to('cuda'), scale_factor=2, mode='nearest')
basically_equal(out_cpu, out_cuda.to("cpu"))
out_cpu = torch.nn.functional.interpolate(a, scale_factor=2, mode='bilinear', align_corners=True)
out_cuda = torch.nn.functional.interpolate(a.to('cuda'), scale_factor=2, mode='bilinear', align_corners=True)
basically_equal(out_cpu, out_cuda.to("cpu"))
# upsample 3d
a = torch.arange(96).reshape(2, 2, 2, 3, 4).contiguous(memory_format=torch.channels_last_3d).float()
out_cpu = torch.nn.functional.interpolate(a, scale_factor=3, mode='nearest')
out_cuda = torch.nn.functional.interpolate(a.to('cuda'), scale_factor=3, mode='nearest')
basically_equal(out_cpu, out_cuda.to("cpu"))
out_cpu = torch.nn.functional.interpolate(a, scale_factor=3, mode='trilinear', align_corners=True)
out_cuda = torch.nn.functional.interpolate(a.to('cuda'), scale_factor=3, mode='trilinear', align_corners=True)
basically_equal(out_cpu, out_cuda.to("cpu"))
```
prints
```
tensor(True)
tensor(True)
tensor(True)
tensor(True)
```
One thing that was weird: `upsample_bilinear2d` and `upsample_trilinear3d` only matched across cpu/cuda to within an epsilon of `1e-4`. That tentatively sounds close enough to say that cuda isn't "wrong" (?), but it's not exactly "equal" either. I also ran the script before my change, and `bilinear2d` and `trilinear3d` matched across cpu/cuda to within the same `1e-4` epsilon, so the discrepancy seems to predate this fix.
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D27351393
Pulled By: bdhirsh
fbshipit-source-id: b33f46e4855dc8b49b363770190b639beebbf5a7