text-generation-inference
c6071749 - Fix mask passed to flashinfer (#3324)

Committed 101 days ago
Custom masks are padded to the shape `[batch_size, max_len, max_len]`. However, flashinfer expects an unpadded mask of the shape `[sum(q_len[i] * k_len[i] for i in range(batch_size))]`. This change unpads the custom mask (currently only used by Gemma 3) to this shape, assuming `q_len == k_len`, since the custom mask is only used during prefill.
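The unpadding described above can be sketched as follows. This is a minimal illustration of the shape transformation, not the actual TGI code: the function name `unpad_custom_mask` and the use of numpy are assumptions for demonstration purposes. Each sequence's `n x n` sub-mask is sliced out of the padded tensor, flattened, and the per-sequence pieces are concatenated into the flat shape flashinfer expects.

```python
import numpy as np

def unpad_custom_mask(padded_mask, seq_lens):
    """Flatten a padded [batch_size, max_len, max_len] mask into the
    unpadded [sum(n * n for n in seq_lens)] layout (q_len == k_len,
    as during prefill). Hypothetical helper for illustration."""
    pieces = []
    for i, n in enumerate(seq_lens):
        # Slice the valid n x n region for sequence i and flatten it.
        pieces.append(padded_mask[i, :n, :n].reshape(-1))
    return np.concatenate(pieces)

# Example: two sequences of lengths 2 and 3, padded to max_len = 3.
batch_size, max_len = 2, 3
mask = np.zeros((batch_size, max_len, max_len), dtype=bool)
mask[0, :2, :2] = np.tril(np.ones((2, 2), dtype=bool))  # causal mask, len 2
mask[1, :3, :3] = np.tril(np.ones((3, 3), dtype=bool))  # causal mask, len 3

flat = unpad_custom_mask(mask, [2, 3])
print(flat.shape)  # (13,) == 2*2 + 3*3
```

The flat layout removes the padding entries entirely, so the total length is `sum(q_len[i] * k_len[i])` rather than `batch_size * max_len * max_len`.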