pytorch
88208c6f - [inductor][cpp] fix mul for uint8 (#98473)

Commit
1 year ago
[inductor][cpp] fix mul for uint8 (#98473) Fixes #98149 The type of `mul`'s output is not inconsistent with its input. This PR fixes the type of `mul`'s output. Here is the output code for the newly added test case `pow+cos`. `tmp4` is 1024 before fixing and 0 after fixing. #### Before fixing ``` auto tmp0 = in_ptr0[static_cast<long>(0)]; // tmp0 is unsigned_char auto tmp1 = tmp0 * tmp0; // tmp1 is int auto tmp2 = tmp1 * tmp1; // tmp2 is int auto tmp3 = tmp2 * tmp0; // tmp3 is int auto tmp4 = static_cast<float>(tmp3); // tmp4 is float auto tmp5 = std::cos(tmp4); out_ptr0[static_cast<long>(0)] = tmp5; ``` #### After fixing ``` auto tmp0 = in_ptr0[static_cast<long>(0)]; // tmp0 is unsigned_char auto tmp1 = decltype(tmp0)(tmp0 * tmp0); // tmp1 is unsigned_char auto tmp2 = decltype(tmp1)(tmp1 * tmp1); // tmp2 is unsigned_char auto tmp3 = decltype(tmp2)(tmp2 * tmp0); // tmp3 is unsigned_char auto tmp4 = static_cast<float>(tmp3); // tmp4 is float auto tmp5 = std::cos(tmp4); out_ptr0[static_cast<long>(0)] = tmp5; ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/98473 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/jansel
Author
Committer
Parents
Loading