Batched gradient support for view+inplace operations (#47227)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47227
Motivation
----------
We would like to compute batched gradients for view+inplace operations.
This most notably shows up in the internal implementations of operations.
For example, many view backward functions (SelectBackward, DiagonalBackward)
are implemented with view+inplace, so to support vectorized Hessian
computation for e.g. torch.select and torch.diagonal we need a way to
handle or work around view+inplace.
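For concreteness, here is a hedged sketch of the kind of computation this
enables. It uses `torch.autograd.functional.hessian` with the experimental
`vectorize=True` flag (available in later PyTorch releases), which routes
the Hessian computation through vmap:
```
import torch

def f(x):
    # DiagonalBackward is one of the view+inplace-based backward
    # functions this change targets.
    return torch.diagonal(x).cos().sum()

x = torch.randn(4, 4)
# vectorize=True computes all rows of the Hessian in one vmapped
# backward pass instead of a Python loop.
H = torch.autograd.functional.hessian(f, x, vectorize=True)
print(H.shape)  # torch.Size([4, 4, 4, 4])
```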
Approach
--------
view+inplace creates a CopySlices node and transmutes view backward nodes
into AsStridedBackward nodes. For example:
```
import torch

leaf = torch.randn(4, 5, requires_grad=True)
base = leaf * leaf   # base.grad_fn is MulBackward
view = base[0]       # a view into base
view.cos_()          # in-place op on the view rewrites the graph
```
base.grad_fn is CopySlices and view.grad_fn is AsStridedBackward.
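Continuing the snippet above, this can be checked directly (exact class
names may vary slightly across releases, e.g. AsStridedBackward0):
```
print(type(base.grad_fn).__name__)  # CopySlices
print(type(view.grad_fn).__name__)  # AsStridedBackward
```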
To support vmap over CopySlices and AsStridedBackward (see the sketch
after this list):
- We use `new_empty_strided` instead of `empty_strided` in CopySlices
  so that the batch dims get propagated.
- We use `new_zeros` inside AsStridedBackward so that the batch dims get
  propagated.
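The distinction matters because the `Tensor.new_*` factory methods
dispatch on `self`, so a batched `self` inside vmap can propagate its
batch dims (along with dtype and device) to the result, while the free
function `torch.empty_strided` has nothing to inherit from. A minimal
sketch of the inheritance behavior, outside of any vmap internals:
```
import torch

base = torch.randn(4, 5, dtype=torch.float64)
a = torch.empty_strided((4, 5), (5, 1))     # free function: default dtype
b = base.new_empty_strided((4, 5), (5, 1))  # inherits dtype/device from base
c = base.new_zeros((4, 5))                  # likewise, and zero-filled
print(a.dtype, b.dtype, c.dtype)            # float32 float64 float64
```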
Test Plan
---------
- New tests. When we get closer to having most operations support batched
  gradient computation via vmap, I'd like to add batched-grad checking as
  an option to gradcheck and turn it on for our tests (sketched below).
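A hedged sketch of what that gradcheck option could look like (in later
PyTorch releases it landed as `check_batched_grad=True`, which compares
vmapped batched gradients against a per-sample loop):
```
import torch
from torch.autograd import gradcheck

x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
# SelectBackward is one of the view+inplace-based backward functions
# mentioned above.
assert gradcheck(lambda t: t.select(0, 0).cos(), (x,), check_batched_grad=True)
```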
Test Plan: Imported from OSS
Reviewed By: kwanmacher, glaringlee
Differential Revision: D24741687
Pulled By: zou3519
fbshipit-source-id: 8210064f782a0a7a193752029a4340e505ffb5d8