pytorch
3607478e - Conjugate View (#54987)

Commit
3 years ago
Conjugate View (#54987) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54987 Based off of ezyang (https://github.com/pytorch/pytorch/pull/44799) and bdhirsh (https://github.com/pytorch/pytorch/pull/43702) 's prototype: Here's a summary of the changes in this PR: This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor. 2. NEW API: a) `.conj()` -- now returning a view. b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory. c) `.conj_physical_()`, and `out=` variant d) `.resolve_conj()` -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0. e) `.resolve_conj_()` in-place version of (d) f) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors. g) `view_as_real` -- existing function, but now errors out on conjugated tensors. 3. Conjugate Fallback a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor. b) This fallback is well equipped to handle the following cases: - functional operation e.g., `torch.sin(input)` - Mutable inputs and in-place operations e.g., `tensor.add_(2)` - out-of-place operation e.g., `torch.sin(input, out=out)` - Tensorlist input args - NOTE: Meta tensors don't work with conjugate fallback. 4. Autograd a) `resolve_conj()` is an identity function w.r.t. autograd b) Everything else works as expected. 5. Testing: a) All method_tests run with conjugate view tensors. b) OpInfo tests that run with conjugate views - test_variant_consistency_eager/jit - gradcheck, gradgradcheck - test_conj_views (that only run for `torch.cfloat` dtype) NOTE: functions like `empty_like`, `zero_like`, `randn_like`, `clone` don't propagate the conjugate bit. Follow up work: 1. conjugate view RFC 2. Add neg bit to re-enable view operation on conjugated tensors 3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D28227315 Pulled By: anjali411 fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f
Author
Parents
Loading