pytorch
de44a50f - index_backward: use out-of-place index_put if any input is subclass (#71779)

Commit
3 years ago
index_backward: use out-of-place index_put if any input is subclass (#71779) Summary: Reference: https://github.com/pytorch/functorch/issues/393 Context : The derivative of `__getitem__`/`index` is https://github.com/pytorch/pytorch/blob/f5a71ec2d6956a1ba393657f4b8297bb931eeec2/tools/autograd/derivatives.yaml#L733-L734 where `index_backward` is defined as https://github.com/pytorch/pytorch/blob/f5a71ec2d6956a1ba393657f4b8297bb931eeec2/torch/csrc/autograd/FunctionsManual.cpp#L3892-L3894 Problem arises when `grad` is not BatchedTensor but one of the other input is. In that case, `grad.new_zeros` returns an unbatched tensor and call to the inplace `_index_put_impl_` errors as it expects `zeros_like_self` to be Batched. To avoid this, we dispatch to out-of-place `index_put` if any of the input tensor is subclassed otherwise we dispatch to the inplace `_index_put_impl_`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/71779 Reviewed By: albanD Differential Revision: D33790596 Pulled By: zou3519 fbshipit-source-id: 9d6d81b758740cab7b3db9b905f1e8053f82b835 (cherry picked from commit ba0407a86ef3cabf885cd127649fa6dcd7f75117)
Author
Committer
Parents
Loading