Implement topk with sort for some cases (#68632)
Summary:
Benchmark that compares the original implementation with the sort-based implementation (run this code on a branch without this patch):
```python
import torch
import timeit
# Each tune_* wrapper sweeps one benchmark parameter and forwards it to f.
def tune_dtype(f):
    def ret(*args, **kwargs):
        for dtype in [torch.int8, torch.half, torch.float, torch.double]:
            f(*args, **kwargs, dtype=dtype)
    return ret

def tune_slice(f):
    def ret(*args, **kwargs):
        slice = 1
        while slice <= 256:
            f(*args, **kwargs, slice=slice)
            slice *= 2
    return ret

def tune_slice_size(f):
    def ret(*args, **kwargs):
        slice_size = 1
        while slice_size <= 1_000_000:
            f(*args, **kwargs, slice_size=slice_size)
            slice_size *= 10
    return ret

def tune_k(f):
    def ret(*args, slice_size, **kwargs):
        k = 1
        while k <= slice_size:
            f(*args, **kwargs, k=k, slice_size=slice_size)
            k *= 10
    return ret

# Sort-based alternative: sort the whole slice, then keep the first k entries.
def topk_with_sort(tensor, k, dim=-1, largest=True):
    values, indices = tensor.sort(dim=dim, descending=largest)
    return values.narrow(dim, 0, k), indices.narrow(dim, 0, k)

# Run f 50 times and synchronize once, so the measured time covers all launches.
def run50sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()

def warmup():
    N = 1000000
    for i in range(1, N // 10000):
        torch.randn(i, device='cuda')

def benchmark_one(slice, slice_size, k, dtype):
    input_ = torch.empty((slice, slice_size), dtype=dtype, device="cuda").random_()
    torch.cuda.synchronize()
    time = timeit.timeit(lambda: run50sync(lambda: torch.topk(input_, k, dim=1)), number=1)
    torch.cuda.synchronize()
    time_sort = timeit.timeit(lambda: run50sync(lambda: topk_with_sort(input_, k, dim=1)), number=1)
    method = "orig" if time < time_sort else "sort"
    speedup = time / time_sort
    print(f"(dtype={dtype}, slice={slice}, slice_size={slice_size}, k={k}) -> (method={method}, speedup={speedup})")

if __name__ == "__main__":
    warmup()
    tune_dtype(tune_slice(tune_slice_size(tune_k(benchmark_one))))()
```
See the next comment for the benchmark results.
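For illustration only, a minimal Python-level sketch of how such benchmark data could drive the choice between the two code paths. The `topk_dispatch` name and the `sort_threshold` cutoff are hypothetical and are not the heuristic this patch actually uses; they only show the shape of the dispatch decision the benchmark informs:
```python
import torch

def topk_dispatch(tensor, k, dim=-1, largest=True, sort_threshold=0.1):
    # Hypothetical cutoff: when k is a large fraction of the slice,
    # sorting the whole slice and narrowing tends to win in the benchmark above.
    slice_size = tensor.size(dim)
    if k >= sort_threshold * slice_size:
        values, indices = tensor.sort(dim=dim, descending=largest)
        return values.narrow(dim, 0, k), indices.narrow(dim, 0, k)
    # Small k: fall back to the existing torch.topk kernel.
    return torch.topk(tensor, k, dim=dim, largest=largest)
```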
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68632
Reviewed By: dagitses
Differential Revision: D32566233
Pulled By: ngimel
fbshipit-source-id: f7a508176ef3685b491048c4a6562121c60b8b2a