pytorch
9a957043 - Update on "Complex gradcheck logic"

Commit
4 years ago
Update on "Complex gradcheck logic" This PR adds gradcheck for complex. The logic used for complex gradcheck is described in Section 3.5.3 here: https://arxiv.org/pdf/1701.00392.pdf This PR is bc breaking because after this PR, we will save intermediate variables for the backward calculation of mul.Scalar and mul.Tensor and in place modification of input tensor would lead to RuntimeError. ``` mu = torch.ones(1, requires_grad=True) x = torch.empty(1) loss = 0 for i in range(3): x.detach_() x.copy_(mu + i) ft = torch.tensor([float(i)]) multiplied = x * ft s = multiplied.sum() loss += s loss.backward() ``` More concretely, this PR introduces the following changes: 1. Updates get_numerical_jacobian to take as input a scalar value for vector (v). Adds gradcheck logic for C -> C, C-> R, R -> C. For R -> C functions, only the real value of gradient is propagated. 2. Adds backward definition for `torch.complex` and also adds a test to verify the definition added. 3. Updates backward for `mul`, `rsqrt`, `sin`, `cos`, `asin`, `acos`, `sinh`, `cosh`. 4. Adds tests for all `torch.real`, `torch.imag`, `torch.view_as_real`, `torch.view_as_complex`, `torch.conj`. Follow up tasks: 1. Add gradgradcheck for `rsqrt` which would need updating the backward formula for `pow`. 2. Add more thorough tests for R -> C cases. Specifically, add R->C test variants for functions. for e.g., `torch.mul(complex_tensor, real_tensor)` 3. Add back commented test in `common_methods_invocation.py`. 4. Try and find more special case checking for complex gradcheck to make debugging easier. [ghstack-poisoned]
Author
Loading