pytorch
190e284b - [Inductor] apply vec float mask on logical comparison ops in cpp (#96502)

Commit
2 years ago
[Inductor] apply vec float mask on logical comparison ops in cpp (#96502) Fix https://github.com/pytorch/pytorch/issues/96446 The root cause is that the logical comparison op works on the integer vector which is later used in the `where` op that expects a float vector. 1. Make sure float vec mask is applied on logical comparison ops. 2. Fix vec int specialization for `to_float_mask`. Assume int mask as input and returns the float mask with reinterpret cast. 3. Add a no-op specialization for `to_float_mask` function with the float vec as input. 4. Pass value instead of ref to `to_float_mask`. Passing by value should be efficient enough. 5. Remove a conditional check `!=0` in `masked()` since `to_float_mask` is guaranteed to return a float mask. Pull Request resolved: https://github.com/pytorch/pytorch/pull/96502 Approved by: https://github.com/EikanWang, https://github.com/XiaobingSuper, https://github.com/jansel
Author
Jiong Gong
Committer
Parents
Loading