CUDA TopK Optimization: use multiple blocks per slice (#71081)
Summary:
# Overview
Currently the CUDA topk implementation uses only 1 block per slice, which limits performance for large slices. This PR addresses that issue.
The topk calculation has two parts: first find the kth value in each slice (`radixFindKthValues`), then gather the topk values (`gatherTopK`) based on that kth value. The `radixFindKthValues` kernel now supports multiple blocks; `gatherTopK` may also need a multi-block version (separate PR?).
kthvalue, quantile, and median could also reuse the same code (separate PR).
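To make the two-phase structure concrete, here is a minimal NumPy sketch of the same idea on the CPU. `topk_two_phase` is a hypothetical helper name for illustration only, not part of the PR; `np.partition` stands in for the radix-select phase (`radixFindKthValues`), and the boolean gather stands in for `gatherTopK`.

```python
import numpy as np

def topk_two_phase(x, k):
    """Illustrative two-phase topk over the last dimension of a 2-D array.
    Phase 1: find the kth largest value per slice (the selection threshold).
    Phase 2: gather every element above the threshold, padding with copies
    of the threshold to account for duplicates of the kth value."""
    out = []
    for slice_ in x:  # one slice == one row
        # Phase 1: kth largest value via partial selection (radix select
        # in the real kernel; np.partition here for illustration).
        kth = np.partition(slice_, -k)[-k]
        # Phase 2: gather values strictly above the threshold; there are
        # at most k-1 of them, so pad the remainder with the kth value.
        above = slice_[slice_ > kth]
        vals = np.concatenate([above, np.full(k - above.size, kth)])
        out.append(vals)
    return np.stack(out)

np.random.seed(0)
x = np.random.randn(4, 100)
vals = topk_two_phase(x, 5)
# As multisets, the result matches the top 5 per row from a full sort.
ref = -np.sort(-x, axis=1)[:, :5]
assert np.allclose(np.sort(vals, axis=1), np.sort(ref, axis=1))
```

The gather phase returns the values unsorted, matching the `sorted=False` path benchmarked below; the real kernels additionally track the input indices.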
# Benchmark
Benchmark result with input `x = torch.randn((D1, D2), dtype=torch.float32)` and `k = 2000` on an RTX 3080: https://docs.google.com/spreadsheets/d/1BAGDkTCHK1lROtjYSjuu_nLuFkwfs77VpsVPymyO8Gk/edit?usp=sharing
Benchmark plots: the left plot is the multi-block implementation; the right is dispatched based on heuristics derived from the Google sheet above.
<p class="img">
<img width=49% src="https://user-images.githubusercontent.com/9999318/150860547-7e450ed2-df09-4292-a02a-cb0e1040eebe.png">
<img width=49% src="https://user-images.githubusercontent.com/9999318/150860579-672b88ca-e500-4846-825c-65d31d126df4.png">
</p>
The performance of the divide-and-conquer implementation at https://github.com/pytorch/pytorch/pull/39850 is not stable as the D1 and D2 sizes increase; for more detail, please check the Google sheet above.
<p>
<img width=49% src="https://user-images.githubusercontent.com/9999318/150860563-21d5a5a3-9d6a-4cef-9031-cac4d2d8edee.png">
</p>
# cubin binary size
The cubin binary size for TensorTopK.cubin (topk) and Sorting.cubin (kthvalue, quantile, etc.) has been reduced by removing `#pragma unroll` at [SortingRadixSelect.cuh](https://github.com/pytorch/pytorch/pull/71081/files#diff-df06046dc4a2620f47160e1b16b8566def855c0f120a732e0d26bc1e1327bb90L321) and the `largest` template argument, without much performance regression.
The final binary sizes before and after the PR are:
```
# master
-rw-rw-r-- 1 richard richard 18M Jan 24 20:07 TensorTopK.cu.1.sm_86.cubin
-rw-rw-r-- 1 richard richard 16M Jan 24 20:07 Sorting.cu.1.sm_86.cubin
# this PR
-rw-rw-r-- 1 richard richard 5.0M Jan 24 20:11 TensorTopK.cu.1.sm_86.cubin
-rw-rw-r-- 1 richard richard 2.5M Jan 24 20:11 Sorting.cu.1.sm_86.cubin
```
Script to extract the cubin files:
```
# build with REL_WITH_DEB_INFO=0
# at pytorch directory
cubin_path=build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/cubin
mkdir -p $cubin_path && cd $cubin_path
find ../ -type f -name '*cu.o' -exec cuobjdump {} -xelf all \;
ls -lh *.cubin -S | head -70
```
# benchmark script
```py
import time

import numpy as np
import pandas as pd
import torch
import torch.utils.benchmark as benchmark

torch.manual_seed(1)
dtype = torch.float
data = []
for d1 in [1, 20, 40, 60, 80, 100, 200, 400, 800, 1000, 2000, 4000, 6000, 8000, 10000, 100000, 500000]:
    if d1 <= 1000:
        D2 = [100, 200, 300, 400, 800, 1000, 2000, 3000, 4000, 5000, 8000, 10000, 20000, 30000, 40000, 80000, 100000, 200000, 300000, 400000, 500000]
    else:
        D2 = [100, 200, 300, 400, 800, 1000, 5000, 10000, 20000, 30000]
    for d2 in D2:
        k = 2000 if d2 >= 2000 else d2 // 2
        print(f"----------------- D1 = {d1}, D2 = {d2} -----------------")
        try:
            x = torch.randn((d1, d2), dtype=dtype, device="cuda")
            m = benchmark.Timer(
                stmt='x.topk(k=k, dim=1, sorted=False, largest=True)',
                globals={'x': x, 'k': k},
                num_threads=1,
            ).blocked_autorange(min_run_time=1)
            print(m)
            time_ms = m.median * 1000
        except RuntimeError:  # OOM
            time_ms = -1
        data.append([d1, d2, k, time_ms])
df = pd.DataFrame(data=data, columns=['D1', 'D2', 'k', 'time(ms)'])
print(df)
df.to_csv('benchmark.csv')
```
The plot script can be found at: https://github.com/yueyericardo/misc/tree/master/share/topk-script
cc zasdfgbnm ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71081
Reviewed By: albanD
Differential Revision: D33823002
Pulled By: ngimel
fbshipit-source-id: c0482664e9d74f7cafc559a07c6f0b564c9e3ed0
(cherry picked from commit be367b8d076aebf53ab7511f6a8a86834c76c95b)