pytorch
2ec3e803 - Update accumulate_grad to support vmap (#49119)

Commit
4 years ago
Update accumulate_grad to support vmap (#49119) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49119 I don't know how the accumulate_grad code gets hit via calling autograd.grad, so I went through all places in accumulate_grad that are definitely impossible to vmap through and changed them. To support this: - I added vmap support for Tensor::strides(). It returns the strides that correspond to the public dimensions of the tensor (not the ones being vmapped over). - Changed an instance of empty_strided to new_empty_strided. - Replaced an in-place operation in accumulate_grad.h Test Plan: - added a test for calling strides() inside of vmap - added tests that exercise all of the accumulate_grad code path. NB: I don't know why these tests exercise the code paths, but I've verified that they do via gdb. Suggestions for some saner test cases are very welcome. Reviewed By: izdeby Differential Revision: D25563543 Pulled By: zou3519 fbshipit-source-id: 05ac6c549ebd447416e6a07c263a16c90b2ef510
Author
Parents
Loading