pytorch
1e70d217 - Add error message for complex alpha and non-complex inputs (#54964)

Commit
4 years ago
Add error message for complex alpha and non-complex inputs (#54964) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54964 Previously, the following would error out with a strange error message: ``` import torch x=torch.randn(2) torch.rsub(x, 1, alpha=2j) Traceback (most recent call last) <ipython-input-2-caf2a1c03d0b> in <module> 1 import torch 2 x=torch.randn(2) ----> 3 torch.rsub(x, 1, alpha=2j) RuntimeError: value cannot be converted to type float without overflow: (-0,-2) ``` The reason why this is happening is because the alpha check doesn't check for if `x` is not complex and `alpha` is complex. The error gets thrown further along in the implementation of torch.sub, when it coerces `alpha` to be the same dtype as the input tensor: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L53 This PR fixes the bad error message by adding a new check to the alpha check. Test Plan: - pytest test/test_binary_ufuncs.py - NB: add, sub, and rsub all share the same alpha check. The test only tests it for torch.add, but that should be sufficient. Reviewed By: gchanan Differential Revision: D27504017 Pulled By: zou3519 fbshipit-source-id: 70b9aa75a7a4faaaa93f6ba235cae85998a91697
Author
Parents
Loading