Optimize memory usage in logsumexp_out (#51239)
Summary:
Partly fixes https://github.com/pytorch/pytorch/issues/31837.
Currently, `torch.logsumexp(input, out=result)` internally creates two intermediate tensors with the same shape as the `input` tensor. This can cause unnecessary out-of-memory (OOM) errors when the input is large.
These tensors come from the following:
1. `self - maxes` creates a new tensor with the shape of `self`
2. `at::exp` creates another tensor with the shape of `self` (both steps are sketched below)
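For reference, here is a minimal Python sketch of the computation described above. The actual implementation lives in ATen C++; `naive_logsumexp_out` is a hypothetical name, and edge cases (e.g., infinite maxima) are omitted:

```python
import torch

def naive_logsumexp_out(self, dim, keepdim, result):
    # Hypothetical, simplified mirror of the pre-change computation.
    maxes = self.amax(dim, keepdim=True)   # reduced shape, cheap
    shifted = self - maxes                 # intermediate 1: full input shape
    exped = torch.exp(shifted)             # intermediate 2: full input shape
    torch.sum(exped, dim, keepdim=keepdim, out=result)
    result.log_()
    result.add_(maxes if keepdim else maxes.squeeze(dim))
    return result
```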
To mitigate this, we can use `(self - maxes).exp_()`, which performs the exp operation in place. This reduces the peak memory requirement from roughly 3x the size of `input` to roughly 2x, since the `self - maxes` intermediate is still allocated.
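Under the same assumptions as the sketch above, the change amounts to reusing the subtraction's buffer for the exponentiation:

```python
import torch

def logsumexp_out_v2(self, dim, keepdim, result):
    maxes = self.amax(dim, keepdim=True)
    # (self - maxes) is the single remaining full-size temporary;
    # exp_() reuses its storage instead of allocating a second one.
    torch.sum((self - maxes).exp_(), dim, keepdim=keepdim, out=result)
    result.log_()
    result.add_(maxes if keepdim else maxes.squeeze(dim))
    return result
```

A quick sanity check against the public API:

```python
x = torch.randn(512, 1024)
out = torch.empty(512)
logsumexp_out_v2(x, dim=1, keepdim=False, result=out)
assert torch.allclose(out, torch.logsumexp(x, dim=1))
```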
I don't think we can eliminate the last intermediate tensor with the shape of `input`, because `self - maxes` must allocate a new tensor to keep `self` intact. The only alternative would be a `torch.Tensor.logsumexp_` method that operates in place on the tensor itself; however, I didn't find any in-place variants of reduction operations, so it might not be a good fit.
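Purely for illustration, a destructive variant along those lines might look like the following sketch; `logsumexp_destructive` is a hypothetical name, not an existing or proposed API:

```python
def logsumexp_destructive(self, dim, keepdim=False):
    # Hypothetical sketch only: mutates `self` in place, so it could never
    # be a drop-in replacement for the out= variant above.
    maxes = self.amax(dim, keepdim=True)
    self.sub_(maxes).exp_()  # no full-size temporary, but `self` is destroyed
    result = self.sum(dim, keepdim=keepdim)
    return result.log_().add_(maxes if keepdim else maxes.squeeze(dim))
```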
This is my first contribution here; please let me know if I'm missing anything!
Thanks!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51239
Reviewed By: anjali411
Differential Revision: D27363147
Pulled By: ezyang
fbshipit-source-id: 696fa8764b74386a80b4aa33104f3f9ca57ed712