Update accumulate_grad to support vmap (#49119)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49119
I don't know how the accumulate_grad code gets hit via calling
autograd.grad, so I went through all places in accumulate_grad
that are definitely impossible to vmap through and changed them.
To support this:
- I added vmap support for Tensor::strides(). It returns the strides
that correspond to the public dimensions of the tensor (not the ones
being vmapped over).
- Changed an instance of empty_strided to new_empty_strided.
- Replaced an in-place operation in accumulate_grad.h
Test Plan:
- added a test for calling strides() inside of vmap
- added tests that exercise all of the accumulate_grad code path.
NB: I don't know why these tests exercise the code paths, but I've
verified that they do via gdb.
Suggestions for some saner test cases are very welcome.
Reviewed By: izdeby
Differential Revision: D25563543
Pulled By: zou3519
fbshipit-source-id: 05ac6c549ebd447416e6a07c263a16c90b2ef510