Beam search TopK improvement (#13594)
### Description
TopK in BeamSearch retrieves the top 2*beam next tokens based on logit
score, i.e. it computes the top [batch, 2*beam] tokens from scores of
shape [batch, beam, vocab_size].
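The reduction being optimized can be stated concisely in NumPy (a minimal sketch with made-up shapes, not ONNX Runtime code):

```python
import numpy as np

# Hypothetical shapes mirroring the description: per batch entry, take the
# top 2*beam scores over all beam * vocab_size candidates.
batch, beam, vocab_size = 2, 4, 8
k = 2 * beam

rng = np.random.default_rng(0)
scores = rng.standard_normal((batch, beam, vocab_size)).astype(np.float32)

# Flatten [beam, vocab_size] per batch entry and take the top 2*beam.
flat = scores.reshape(batch, beam * vocab_size)
top_idx = np.argsort(-flat, axis=-1)[:, :k]           # [batch, 2*beam]
top_val = np.take_along_axis(flat, top_idx, axis=-1)  # [batch, 2*beam]
```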
### Motivation and Context
The current implementation uses batch as the grid dimension, so each
thread block computes the top 2*beam from [beam, vocab_size]. This is
inefficient because:
1. batch size is usually small (<32) and cannot fully occupy the GPU's
SMs; 2. vocab_size is usually more than 50k, so reducing 50k * beam
elements in a single thread block is slow.
This PR splits the top-k computation into multiple stages:
- For a small beam size, split [batch, beam, vocab_size] into [batch,
beam, parts_of_vocab, vocab_size_per_part]:
  - 1st stage: each thread block computes the top 2*beam from
vocab_size_per_part elements, producing [batch, beam, parts_of_vocab,
2*beam].
  - 2nd stage: each thread block computes the top 2*beam from
parts_of_vocab * (2*beam) elements, producing [batch, beam, 2*beam].
  - Last stage: compute [batch, 2*beam] from [batch, beam, 2*beam].
- For a large beam size, the 1st stage computes [batch, beam, 2*beam]
from [batch, beam, vocab_size], and the 2nd stage computes [batch,
2*beam] from [batch, beam, 2*beam].
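The staged reduction for the small-beam path can be sketched in NumPy (assumed shapes and a hypothetical `topk` helper, not the actual CUDA kernels). Because each stage keeps the top 2*beam of every group, no candidate in the final top 2*beam is ever discarded early, so the result matches a direct top-k over [beam * vocab_size]:

```python
import numpy as np

batch, beam, vocab_size, parts = 2, 4, 64, 4  # assumed small sizes
k = 2 * beam
part = vocab_size // parts  # vocab_size_per_part

rng = np.random.default_rng(1)
scores = rng.standard_normal((batch, beam, vocab_size)).astype(np.float32)

def topk(vals, idx, k):
    """Top-k along the last axis, carrying indices alongside values."""
    order = np.argsort(-vals, axis=-1)[..., :k]
    return np.take_along_axis(vals, order, -1), np.take_along_axis(idx, order, -1)

# Stage 1: [batch, beam, parts, part] -> top 2*beam per vocab part.
vals = scores.reshape(batch, beam, parts, part)
idx = np.broadcast_to(np.arange(vocab_size).reshape(parts, part), vals.shape)
v1, i1 = topk(vals, idx, k)  # [batch, beam, parts, 2*beam]

# Stage 2: merge parts_of_vocab * (2*beam) -> top 2*beam per beam.
v2, i2 = topk(v1.reshape(batch, beam, -1), i1.reshape(batch, beam, -1), k)

# Last stage: merge beams, [batch, beam, 2*beam] -> [batch, 2*beam].
# Convert per-beam vocab indices to global [beam * vocab_size] indices first.
gidx = i2 + np.arange(beam)[None, :, None] * vocab_size
v3, i3 = topk(v2.reshape(batch, -1), gidx.reshape(batch, -1), k)

# Sanity check: staged result equals a direct top-k over the full scores.
ref_v, ref_i = topk(scores.reshape(batch, -1),
                    np.broadcast_to(np.arange(beam * vocab_size),
                                    (batch, beam * vocab_size)), k)
assert np.allclose(v3, ref_v) and np.array_equal(i3, ref_i)
```

Each stage here maps to one kernel launch in the PR's scheme, with the per-stage working set shrinking from vocab_size_per_part to parts_of_vocab * (2*beam) to beam * (2*beam) elements.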
With this change, performance improves significantly: latency drops by
~100us from 2ms for batch:4, beam:4, vocab_size:~50k.