pytorch
91cce4c0 - Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128) (#96223)

Commit
1 year ago
Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128) (#96223)

We currently use `bitonicSortKVInplace` for sorts of size `n <= 32` but use `radixSortKVInplace` for `32 < n <= 4096`. Bitonic sort is also unstable, which forces stable sorts to fall back to radix sort, which is up to 4x slower in this small regime.

This PR adds a new kernel `warpMergeSortKVInplace` using `cub::WarpMergeSort` to implement sorts with `32 < n <= 128` and all stable sorts with `n < 128`. This results in up to a 2x speedup for unstable sorts and up to 15x for stable sorts, depending on the input geometry.

This also doesn't increase the total number of kernels, since we are replacing radix sorts of size 32 and 128.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96223
Approved by: https://github.com/ngimel
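The size-based kernel selection described in this message can be sketched as follows. This is a minimal illustration, not the actual PyTorch dispatch code: the function `choose_small_sort_kernel` and the enum `SortKernel` are hypothetical names, and only the kernel names stored as enum values come from the commit message. The message gives `n < 128` for stable sorts; the sketch assumes the bound is inclusive, matching the `32 < n <= 128` range in the title. Sizes above 4096 are outside what the message describes, so they map to a placeholder.

```python
from enum import Enum


class SortKernel(Enum):
    # Kernel names are taken from the commit message.
    BITONIC = "bitonicSortKVInplace"
    WARP_MERGE = "warpMergeSortKVInplace"
    RADIX = "radixSortKVInplace"
    # Placeholder: the commit message does not say what handles n > 4096.
    OTHER = "not covered by this commit message"


def choose_small_sort_kernel(n: int, stable: bool) -> SortKernel:
    """Hypothetical helper mirroring the size ranges in the commit message."""
    if n <= 32 and not stable:
        # Bitonic sort is unstable, so it only serves unstable sorts.
        return SortKernel.BITONIC
    if n <= 128:
        # Assumption: inclusive upper bound, as in the commit title.
        # Covers unstable sorts with 32 < n <= 128 and all small stable sorts.
        return SortKernel.WARP_MERGE
    if n <= 4096:
        return SortKernel.RADIX
    return SortKernel.OTHER
```

For example, `choose_small_sort_kernel(16, stable=True)` selects the warp merge sort kernel, which is the case where the message reports the largest speedup, because stable sorts of that size previously had to use radix sort.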