pytorch
367488b0 - Move where cuda implementation to TensorIterator (#32984)

Commit
4 years ago
Move where cuda implementation to TensorIterator (#32984) Summary: `where` is special because the arguments do not have the same type, which does not satisfy the assumption in modern https://github.com/pytorch/pytorch/pull/32383. I migrate it to TensorIterator so that there is something to test that this case is not broken. Currently, this case fallback to using legacy (not vectorized, not unrolled) code. It should be supported in the future when I cleanup `Loops.cuh`. I also move some sharing part of `CUDALoops.cuh` and `ROCmLoops.cuh` into `Loops.cuh` so that to logic for checking whether `func_t` has the same arg types could be shared. Pull Request resolved: https://github.com/pytorch/pytorch/pull/32984 Differential Revision: D19825127 Pulled By: ngimel fbshipit-source-id: bbf4682349d96b4480c4d657f3c18a3a67a9bf17
Author
Parents
Loading