Improve reshape backward when the op is a view (#28901)
Summary:
Currently, `reshape` does an `as_strided` when the geometry is viewable. However, `as_strided` backward is not very optimized, and can not always detect such cases. Improvements are planned at https://github.com/pytorch/pytorch/pull/8965, and I will finish it some day. But the current situation is that in these cases backward through `reshape` will copy gradient while a simple `view` will not. This is unnecessary.
Notably this affects `flatten` and a whole bunch of other ops implemented on top of `reshape`.
```py
In [15]: x = torch.randn(3, 4, requires_grad=True)
In [16]: y = x.reshape(x.shape)
In [17]: assert y._base is not None
In [18]: gy = torch.randn_like(y)
In [20]: gx = torch.autograd.grad(y, x, gy)[0]
In [21]: gx
Out[21]:
tensor([[ 0.2189, 0.3396, -0.1108, 1.7703],
[ 1.0737, -0.1222, 1.0765, -1.3363],
[-1.3798, -0.2950, 0.0800, 0.2501]])
In [22]: gx._base # not gy
Out[22]:
tensor([ 0.2189, 0.3396, -0.1108, 1.7703, 1.0737, -0.1222, 1.0765, -1.3363,
-1.3798, -0.2950, 0.0800, 0.2501])
In [23]: gy.zero_()
Out[23]:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
In [24]: gx # not sharing storage with gy
Out[24]:
tensor([[ 0.2189, 0.3396, -0.1108, 1.7703],
[ 1.0737, -0.1222, 1.0765, -1.3363],
[-1.3798, -0.2950, 0.0800, 0.2501]])
# but everything is optimized with view, which should be equivalent with reshape in this case
In [25]: y = x.view(x.shape)
In [26]: assert y._base is not None
In [27]: gy = torch.randn_like(y)
In [28]: gx = torch.autograd.grad(y, x, gy)[0]
In [29]: gx
Out[29]:
tensor([[-2.4463, 1.1446, 0.1501, 0.1212],
[-1.1125, 1.4661, 0.9092, -0.2153],
[-0.1937, -0.3381, -1.3883, -0.7329]])
In [30]: gy.zero_()
Out[30]:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
In [31]: gx # sharing storage with gy
Out[31]:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28901
Differential Revision: D18240868
Pulled By: ezyang
fbshipit-source-id: 28fdaa0c7014a9dae6731dfe8b67784d38fc27f0