Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128) (#96223)
We currently use `bitonicSortKVInplace` for sorts of size `n <= 32`
but use `radixSortKVInplace` for `32 < n <= 4096`. Bitonic sort is
also unstable, which forces stable sorts to fall back to radix sort,
which is up to 4x slower in this small regime.
This PR adds a new kernel `warpMergeSortKVInplace` using
`cub::WarpMergeSort` to implement sorts with `32 < n <= 128` and all
stable sorts with `n <= 128`. This results in up to a 2x speedup for
unstable sorts and up to 15x for stable sorts, depending on the input
geometry.
This also doesn't increase the total number of kernels, since we are
replacing the radix sorts of size 32 and 128.
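
The size and stability rules described above can be summarized as a small host-side selection function. This is only a sketch of the dispatch logic as this message describes it, not the actual code in the PR: the enum `SmallSortKernel`, the function `chooseSmallSortKernel`, and the constant names are hypothetical, and the upper bound of 128 is assumed to be inclusive for stable sorts, as it is in the title.

```cpp
#include <cstdint>

// Hypothetical enum for illustration; these names do not come from the PR.
enum class SmallSortKernel {
  Bitonic,    // stands in for bitonicSortKVInplace
  WarpMerge,  // stands in for warpMergeSortKVInplace
  Radix,      // stands in for radixSortKVInplace
  Other       // sizes outside the small-sort regime, handled elsewhere
};

// Size thresholds taken from the ranges quoted in this message.
constexpr int64_t kBitonicMax = 32;
constexpr int64_t kWarpMergeMax = 128;  // assumed inclusive
constexpr int64_t kRadixMax = 4096;

// Pick a kernel for a sort of n elements.
constexpr SmallSortKernel chooseSmallSortKernel(int64_t n, bool stable) {
  if (n > kRadixMax) {
    return SmallSortKernel::Other;
  }
  if (n > kWarpMergeMax) {
    return SmallSortKernel::Radix;
  }
  if (n > kBitonicMax) {
    return SmallSortKernel::WarpMerge;
  }
  // n <= 32: bitonic sort is unstable, so stable sorts use the merge sort.
  return stable ? SmallSortKernel::WarpMerge : SmallSortKernel::Bitonic;
}
```

With these assumptions, an unstable sort of 32 elements still goes to the bitonic kernel, while a stable sort of the same size goes to the warp merge sort instead of falling back to radix sort.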
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96223
Approved by: https://github.com/ngimel