pytorch
b7a5c793 - [inductor] Fix type inference in CPU masked operations (#93842)

Commit
1 year ago
[inductor] Fix type inference in CPU masked operations (#93842) Fixes #93351 The existing code guesses that `tmp3` is probably a `float`, and so truncates any `double` values ```cpp float tmp3 = 0.0; if(tmp2) { auto tmp4 = in_ptr0[i0]; tmp3 = tmp4; } ``` The proposed change is to generate a lambda expression that represents the body of the masked operation, and infer the type from the return value: ```cpp auto tmp3 = [&] { auto tmp4 = in_ptr0[i0]; return tmp4; } ; auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0); ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/93842 Approved by: https://github.com/jgong5, https://github.com/Valentine233, https://github.com/jansel
Author
Committer
Parents
Loading