pytorch
a620512d - Support non-standard bools in CUDA masked_scatter (#79391)

Commit
2 years ago
Support non-standard bools in CUDA masked_scatter (#79391) The failure comes from within the `exclusive_sum` used to calculate the selection index. So, I merged the code paths for `uint8_t` and `bool` masks together, then use `cub::TransformIterator` to get a proper `bool` value for the result. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79391 Approved by: https://github.com/mruberry, https://github.com/mikaylagawarecki
Author
Committer
Parents
Loading