pytorch
57dcb042 - Batched gradient support for view+inplace operations (#47227)

Commit
4 years ago
Batched gradient support for view+inplace operations (#47227) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47227 Motivation ---------- We would like to compute batched gradients for view+inplace operations. This most notably shows up in internal implementation of operations. For example, many view backward functions (SelectBackward, DiagonalBackward) are implemented with view+inplace, so to support vectorized hessian computation for e.g. torch.select and torch.diagonal we would need a way to handle or workaround view+inplace. Approach -------- view+inplace creates a CopySlices node and transmute view backward nodes into an AsStrided node. For example, ``` leaf = torch.randn(4, 5, requires_grad=True) base = leaf * leaf view = base[0] view.cos_() ``` base.grad_fn is CopySlices and view.grad_fn is AsStridedBackward. To support vmap over CopySlices and AsStridedBackward: - We use `new_empty_strided` instead of `empty_strided` in CopySlices so that the batch dims get propagated - We use `new_zeros` inside AsStridedBackward so that the batch dims get propagated. Test Plan --------- - New tests. When we get closer to having most operations support batched grad computation via vmap, I'd like to add it as an option to gradcheck and turn it on for our tests. Test Plan: Imported from OSS Reviewed By: kwanmacher, glaringlee Differential Revision: D24741687 Pulled By: zou3519 fbshipit-source-id: 8210064f782a0a7a193752029a4340e505ffb5d8
Author
Parents
Loading