e8faf697 - fix torch.pow type promotion issue (#54085)

fix torch.pow type promotion issue (#54085)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54085

Fixes https://github.com/pytorch/pytorch/issues/50121. This fixes two similar issues with the dtype in which `torch.pow` performs its computation. Thanks ngimel for spotting the issues originally (comments [here](https://github.com/pytorch/pytorch/pull/53669#discussion_r594624355) and [here](https://github.com/pytorch/pytorch/pull/53669#discussion_r594719704))!

Before:
```
>>> torch.pow(2, torch.tensor([17], dtype=torch.uint8), out=torch.tensor([0]))
tensor([0])
>>> torch.pow(2, torch.tensor(17, dtype=torch.uint8), out=torch.tensor(0))
tensor(131072)
>>> torch.pow(2, torch.tensor([17], dtype=torch.uint8, device='cuda'), out=torch.tensor([0], device='cuda'))
tensor([131072], device='cuda:0')
>>> torch.pow(2, torch.tensor(17, dtype=torch.uint8, device='cuda'), out=torch.tensor(0, device='cuda'))
tensor(131072, device='cuda:0')
```

After:
```
>>> torch.pow(2, torch.tensor([17], dtype=torch.uint8), out=torch.tensor([0]))
tensor([0])
>>> torch.pow(2, torch.tensor(17, dtype=torch.uint8), out=torch.tensor(0))
tensor(0)
>>> torch.pow(2, torch.tensor([17], dtype=torch.uint8, device='cuda'), out=torch.tensor([0], device='cuda'))
tensor([0], device='cuda:0')
>>> torch.pow(2, torch.tensor(17, dtype=torch.uint8, device='cuda'), out=torch.tensor(0, device='cuda'))
tensor(0, device='cuda:0')
```

In all four cases above, `tensor(0, ...)` is the correct value because the computed "common dtype" among the inputs is expected to be `uint8`. Computing `2 ** 17` in uint8 then overflows to zero. Finally, the computed output is cast to the output tensor's dtype, which is `int32`.
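The compute-in-common-dtype-then-cast semantics described above can be simulated without torch; this is an illustrative pure-Python sketch (the `pow_with_common_dtype` helper is not a PyTorch API), modeling uint8 arithmetic as arithmetic modulo 256:

```python
# Sketch of the expected semantics of torch.pow(2, uint8_tensor, out=int32_tensor):
# compute in the common dtype (uint8, which wraps modulo 2**8), then cast the
# result to the out tensor's dtype (int32, a value-preserving cast).

def pow_with_common_dtype(base: int, exponent: int) -> int:
    """Compute base ** exponent in simulated uint8, then 'cast' to int32."""
    uint8_result = pow(base, exponent, 256)  # uint8 arithmetic wraps modulo 256
    # uint8 -> int32 preserves the value, so the wrapped result is returned as-is.
    return uint8_result

print(pow_with_common_dtype(2, 17))  # 2**17 = 131072, which wraps to 0 in uint8
print(pow_with_common_dtype(2, 7))   # 2**7 = 128 fits in uint8, so no wrapping
```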
Two separate issues were fixed in this PR: one for CPU and one for CUDA:

* For CPU, the `pow(Scalar, Tensor)` overload wasn't calling `set_wrapped_number(true)` after wrapping the scalar in a Tensor, which caused the "promoted" scalar to incorrectly participate in type promotion (see the documented behavior [here](https://github.com/pytorch/pytorch/blob/aa8714dfedc73c67524e2394fe04d115f0783a09/c10/core/TensorImpl.h#L590)).
* For CUDA, the kernels defined in `PowKernel.cu` were using the output's dtype to run the computation, instead of the common dtype.

As an aside: the CPU and CUDA kernels actually both used `iter.dtype()` instead of `iter.common_dtype()` to run the computation, which I fixed. The reason this only manifested for CUDA is that TensorIterator has CPU-specific logic to create temporary outputs with the intermediate dtype (shown [here](https://github.com/pytorch/pytorch/blob/aa8714dfedc73c67524e2394fe04d115f0783a09/aten/src/ATen/TensorIterator.cpp#L349)). I'm not sure what the end state is there: I can imagine that being something we're more okay doing for CPU than for CUDA, but it also leads to hard-to-track-down inconsistencies between the two, as in this case.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27096330

Pulled By: bdhirsh

fbshipit-source-id: a7e2909243851625cb3056d1e7abb2383bfe95f2
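The wrapped-number rule behind the CPU bug can be sketched in plain Python. This is a simplified model, not PyTorch's actual type-promotion code; `FakeTensor`, `is_wrapped_number`, and `common_dtype` here are illustrative stand-ins. The idea: a Python scalar wrapped into a tensor is flagged as a wrapped number, and real tensor dtypes win over it during promotion, so forgetting the flag lets the scalar's default `int64` dtype win incorrectly.

```python
# Simplified model of wrapped-number type promotion. Not PyTorch's real
# implementation: real promotion also ranks dtype categories, which is
# omitted here for brevity.

from dataclasses import dataclass

@dataclass
class FakeTensor:
    dtype: str
    is_wrapped_number: bool = False  # True if this wraps a Python scalar

def common_dtype(*tensors: FakeTensor) -> str:
    """Wrapped scalars defer to any real tensor operand's dtype."""
    real = [t for t in tensors if not t.is_wrapped_number]
    if real:
        return real[0].dtype  # simplification: first real tensor wins
    return tensors[0].dtype

exponent = FakeTensor("uint8")  # real uint8 tensor operand

# Correct behavior: scalar 2 is flagged as wrapped, so uint8 wins.
flagged = FakeTensor("int64", is_wrapped_number=True)
print(common_dtype(flagged, exponent))    # uint8

# The bug: set_wrapped_number(true) was skipped, so int64 won instead.
unflagged = FakeTensor("int64")
print(common_dtype(unflagged, exponent))  # int64
```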