onnxruntime
c43ce647 - Beam search TopK improvement (#13594)

Commit
3 years ago
Beam search TopK improvement (#13594)

### Description

TopK in BeamSearch retrieves the top 2*beam next tokens by logit score: it computes the top [batch, 2*beam] tokens from scores of shape [batch, beam, vocab_size].

### Motivation and Context

The current implementation uses batch as the grid, and each thread block computes the top 2*beam from [beam, vocab_size]. This is inefficient because:
1. The batch size is usually small (<32), so the launch cannot fully utilize the GPU's SMs.
2. vocab_size is usually more than 50k, and computing over 50k * beam elements in a single thread block is inefficient.

This PR splits the TopK computation into multiple stages:
- For a small beam size, split [batch, beam, vocab_size] into [batch, beam, parts_of_vocab, vocab_size_per_part]:
  - Stage 1: each thread block computes the top 2*beam from vocab_size_per_part, producing [batch, beam, parts_of_vocab, 2*beam].
  - Stage 2: each thread block computes the top 2*beam from parts_of_vocab * (2*beam), producing [batch, beam, 2*beam].
  - Last stage: compute [batch, 2*beam] from [batch, beam, 2*beam].
- For a large beam size, the first stage computes [batch, beam, 2*beam] from [batch, beam, vocab_size], and the second stage computes [batch, 2*beam] from [batch, beam, 2*beam].

With this change, performance improves significantly: for batch:4, beam:4, vocab_size:~50k, latency drops by ~100us from 2ms.
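The staged reduction above can be sketched on the host with NumPy. This is a hypothetical illustration of the algorithm's logic, not the CUDA kernel from the PR; the function name, `num_parts`, and the use of `argsort` in place of per-block TopK are all illustrative. Each stage only needs to keep 2*beam candidates because the global top 2*beam is always contained in the union of the per-part (or per-beam) top 2*beam sets.

```python
import numpy as np

def staged_beam_topk(scores, k, num_parts=4):
    """Staged top-2k sketch (hypothetical, mirrors the PR's small-beam path).

    scores: [batch, beam, vocab] logit scores.
    Returns (values, flat_indices), both [batch, 2k], where flat_indices
    index into the flattened [beam * vocab] axis so beam search can
    recover both the source beam and the token id.
    """
    batch, beam, vocab = scores.shape
    assert vocab % num_parts == 0, "illustrative: assume vocab splits evenly"
    part = vocab // num_parts
    k2 = 2 * k

    # Stage 1: per (batch, beam, part), top 2k of vocab_size_per_part
    # -> [batch, beam, parts_of_vocab, 2k]
    s = scores.reshape(batch, beam, num_parts, part)
    idx1 = np.argsort(-s, axis=-1)[..., :k2]
    val1 = np.take_along_axis(s, idx1, axis=-1)
    # Convert part-local indices back to vocab indices.
    tok1 = idx1 + np.arange(num_parts)[None, None, :, None] * part

    # Stage 2: per (batch, beam), top 2k of parts_of_vocab * 2k candidates
    # -> [batch, beam, 2k]
    val1 = val1.reshape(batch, beam, num_parts * k2)
    tok1 = tok1.reshape(batch, beam, num_parts * k2)
    idx2 = np.argsort(-val1, axis=-1)[..., :k2]
    val2 = np.take_along_axis(val1, idx2, axis=-1)
    tok2 = np.take_along_axis(tok1, idx2, axis=-1)

    # Last stage: per batch, top 2k of beam * 2k candidates -> [batch, 2k],
    # with indices into the flattened [beam * vocab] axis.
    flat = tok2 + np.arange(beam)[None, :, None] * vocab
    val2 = val2.reshape(batch, beam * k2)
    flat = flat.reshape(batch, beam * k2)
    idx3 = np.argsort(-val2, axis=-1)[..., :k2]
    return (np.take_along_axis(val2, idx3, axis=-1),
            np.take_along_axis(flat, idx3, axis=-1))
```

On the GPU, stage 1 launches a grid of batch * beam * parts_of_vocab blocks instead of just batch blocks, which is what recovers SM occupancy for small batch sizes.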