pytorch
02550bc1 - Support non-standard bools in CUDA mode (#79393)

Commit
2 years ago
Support non-standard bools in CUDA mode (#79393)

Closes #54789

For the `fused_mode` kernel, this just uses `c10::load` but the `apply_mode` function is a bit harder because it uses `thrust`. Instead, I've added a second dedicated path for bool which also only uses 2 thrust calls instead of the normal 6, by exploiting the fact that bools only have two possible values.

In the following `timeit` benchmark which calls the `apply_mode` version, I see execution time drop from 16.9 ms to 2.2 ms (which is still terrible, but my main goal is fixing the bool handling).

```python
import torch
a = torch.randint(
    0, 2, size=(100, 4096), device='cuda', dtype=torch.bool)
%timeit a.mode(1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79393
Approved by: https://github.com/ngimel
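
For illustration only, here is a rough pure-Python sketch of why a bool-only path can get away with far fewer passes: with two possible values, one counting pass decides the mode and one more pass locates an element holding it. This is not the CUDA/thrust code from the PR. The function name `bool_mode` is hypothetical, and both the tie-breaking rule (ties go to `False`) and the choice of returning the last matching index are assumptions made for the sketch. Normalizing any nonzero byte to `True` stands in for what the commit describes `c10::load` doing for non-standard bools.

```python
from typing import Sequence, Tuple


def bool_mode(raw: Sequence[int]) -> Tuple[bool, int]:
    """Return (mode, index of one element equal to the mode) for a row of bool bytes.

    Hypothetical helper, not part of PyTorch.
    """
    if len(raw) == 0:
        raise ValueError("cannot take the mode of an empty row")

    # Normalize "non-standard" bools: any nonzero byte is treated as True.
    values = [b != 0 for b in raw]

    # Pass 1: a single count of True values is enough to pick the mode,
    # because the count of False values follows from the row length.
    num_true = sum(values)
    num_false = len(values) - num_true
    mode = num_true > num_false  # assumption: ties resolve to False

    # Pass 2: find an index whose value equals the mode
    # (assumption: report the last such index).
    index = max(i for i, v in enumerate(values) if v == mode)
    return mode, index
```

Here a byte such as `2` or `255` counts as `True`, so `bool_mode([0, 2, 255, 0, 1])` treats the row as three `True` values and two `False` values.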