optimize topk on cpu using parallel and partial sort (#19736)
Summary:
This PR aims at improving `topk()` performance on CPU. This is useful when computing **beam search** during `Transformer` and `BERT`.
Given a tensor x of size `[N, C]`, and we want to apply `x.topk(K)`, the current logic is **sequentially** loop on the dimension of `N` and do **quick select** on the dimension of `C` so as to find out top K elements.
Performance can be further improved from:
- On the dimension of `N`, it can be paralleled
- Maybe a faster sorting algorithm for `topk`. (After a bunch of experimenting, `std::partial_sort` seems to be the most promising)
So i compared 3 versions:
1. vanilla: sequential + quick select
2. reference PR https://github.com/pytorch/pytorch/issues/19737: parallel + quick select
3. this PR: parallel + partial sort
with the following benchmark, on `Xeon 8180, 2*28 cores@2.5 GHz`:
```python
import torch
from time import time
num_iters = 1000
def bench_topk(N=8, C=168560, k=10):
a = torch.randn(N, C)
# warm up
for i in range(100):
torch.topk(a, k)
t = 0
for i in range(num_iters):
a = torch.randn(N, C)
start = time()
value, indice = torch.topk(a, k)
t += time() - start
print("#[%d, %d] times: %f ms" % (N, C, t / num_iters * 1000))
Ns = [10, 20, 30]
Cs = [10000, 20000, 40000, 80000, 160000, 320000]
for n in Ns:
for c in Cs:
bench_topk(N=n, C=c)
```
### vanilla: sequential + quick select
```
#[10, 10000] times: 0.746740 ms
#[10, 20000] times: 1.437399 ms
#[10, 40000] times: 2.832455 ms
#[10, 80000] times: 5.649426 ms
#[10, 160000] times: 11.309466 ms
#[10, 320000] times: 22.798765 ms
#[20, 10000] times: 1.511303 ms
#[20, 20000] times: 2.822024 ms
#[20, 40000] times: 5.564770 ms
#[20, 80000] times: 11.443044 ms
#[20, 160000] times: 22.747731 ms
#[20, 320000] times: 46.234449 ms
#[30, 10000] times: 2.214045 ms
#[30, 20000] times: 4.236179 ms
#[30, 40000] times: 8.418577 ms
#[30, 80000] times: 17.067578 ms
#[30, 160000] times: 33.826214 ms
#[30, 320000] times: 68.109420 ms
```
### reference PR: parallel + quick select
```
#[10, 10000] times: 0.271649 ms
#[10, 20000] times: 0.593016 ms
#[10, 40000] times: 1.133518 ms
#[10, 80000] times: 2.082355 ms
#[10, 160000] times: 4.049928 ms
#[10, 320000] times: 7.321285 ms
#[20, 10000] times: 0.315255 ms
#[20, 20000] times: 0.539054 ms
#[20, 40000] times: 1.000675 ms
#[20, 80000] times: 1.914586 ms
#[20, 160000] times: 4.437122 ms
#[20, 320000] times: 8.822445 ms
#[30, 10000] times: 0.347209 ms
#[30, 20000] times: 0.589947 ms
#[30, 40000] times: 1.102814 ms
#[30, 80000] times: 2.112201 ms
#[30, 160000] times: 5.186837 ms
#[30, 320000] times: 10.523023 ms
```
### this PR: parallel + partial sort
```
#[10, 10000] times: 0.150284 ms
#[10, 20000] times: 0.220089 ms
#[10, 40000] times: 0.521875 ms
#[10, 80000] times: 0.965593 ms
#[10, 160000] times: 2.312356 ms
#[10, 320000] times: 4.759422 ms
#[20, 10000] times: 0.167630 ms
#[20, 20000] times: 0.265607 ms
#[20, 40000] times: 0.471477 ms
#[20, 80000] times: 0.974572 ms
#[20, 160000] times: 3.269645 ms
#[20, 320000] times: 6.538608 ms
#[30, 10000] times: 0.204976 ms
#[30, 20000] times: 0.342833 ms
#[30, 40000] times: 0.589381 ms
#[30, 80000] times: 1.398579 ms
#[30, 160000] times: 3.904077 ms
#[30, 320000] times: 9.681224 ms
```
In summary, `2` is **5x** faster than `vanilla` on average and `3` is **8.6x** faster than `vanilla`.
On `Fairseq Transformer`, the default parameter on dataset `wmt14` would have a `topk` size of `[8, 168560]`, and this operator gets `3x` faster with this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19736
Differential Revision: D16204820
Pulled By: VitalyFedyunin
fbshipit-source-id: ea70562c9149a0d832cf5872a891042ebd74fc63