Speedup segmented sort with large nsort
Follow up to gh-77100
Instead of calling `at::arange`, I repurpose the existing kernels to
achieve the same effect. I've also fixed the 2d case bug where
the pointer was advanced by `n` which equals `nsegment * nsort` after
only processing `nsort` elements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77188
Approved by: https://github.com/ngimel