pytorch
95f4cd0b - Implement topk with sort for some cases (#68632)

Commit
2 years ago
Implement topk with sort for some cases (#68632)

Summary:
Benchmark that compares the original implementation and the sort implementation (this code should be run on a branch without this patch):

```python
import torch
import timeit

def tune_dtype(f):
    def ret(*args, **kwargs):
        for dtype in [torch.int8, torch.half, torch.float, torch.double]:
            f(*args, **kwargs, dtype=dtype)
    return ret

def tune_slice(f):
    def ret(*args, **kwargs):
        slice = 1
        while slice <= 256:
            f(*args, **kwargs, slice=slice)
            slice *= 2
    return ret

def tune_slice_size(f):
    def ret(*args, **kwargs):
        slice_size = 1
        while slice_size <= 1_000_000:
            f(*args, **kwargs, slice_size=slice_size)
            slice_size *= 10
    return ret

def tune_k(f):
    def ret(*args, slice_size, **kwargs):
        k = 1
        while k <= slice_size:
            f(*args, **kwargs, k=k, slice_size=slice_size)
            k *= 10
    return ret

def topk_with_sort(tensor, k, dim=-1, largest=True):
    values, indices = tensor.sort(dim=dim, descending=largest)
    return values.narrow(dim, 0, k), indices.narrow(dim, 0, k)

def run50sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()

def warmup():
    N = 1000000
    for i in range(1, N // 10000):
        torch.randn(i, device='cuda')

def benchmark_one(slice, slice_size, k, dtype):
    input_ = torch.empty((slice, slice_size), dtype=dtype, device="cuda").random_()
    torch.cuda.synchronize()
    time = timeit.timeit(lambda: run50sync(lambda: torch.topk(input_, k, dim=1)), number=1)
    torch.cuda.synchronize()
    time_sort = timeit.timeit(lambda: run50sync(lambda: topk_with_sort(input_, k, dim=1)), number=1)
    method = "orig" if time < time_sort else "sort"
    speedup = time / time_sort
    print(f"(dtype={dtype}, slice={slice}, slice_size={slice_size}, k={k}) -> (method={method}, speedup={speedup})")

if __name__ == "__main__":
    warmup()
    tune_dtype(tune_slice(tune_slice_size(tune_k(benchmark_one))))()
```

For the benchmark results, see the next comment.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68632

Reviewed By: dagitses

Differential Revision: D32566233

Pulled By: ngimel

fbshipit-source-id: f7a508176ef3685b491048c4a6562121c60b8b2a
Author
Parents
Loading