onnxruntime
c756e0ab - Improve Shape Inference for GQA (#24143)

### Description

For the GroupQueryAttention op, if the input total_sequence_length is a constant, we can infer the shape of the present_key/present_value outputs as `(batch_size, kv_num_heads, present_sequence_length, head_size)`.

https://github.com/microsoft/onnxruntime/blob/5ed900e9712ce2f02e40c15b945d18453d1960d8/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h#L185

From the CPU EP we know that `present_sequence_length = max(past_sequence_length, total_sequence_length)`, and that `batch_size`, `kv_num_heads`, and `head_size` are the same as in past_key/past_value.

### Motivation and Context

This inference is especially important for the WebNN EP, because WebNN only supports GQA when `present_sequence_length == past_sequence_length` and requires static shapes for graph compilation.
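Below is a minimal sketch of the inference rule described above, assuming a hypothetical helper name (`infer_gqa_present_kv_shape`); it is not the actual onnxruntime shape-inference code, only an illustration of how the present KV shape follows from the past KV shape and a constant total_sequence_length.

```python
def infer_gqa_present_kv_shape(past_kv_shape, total_sequence_length):
    """Hypothetical illustration of the GQA present_key/present_value shape rule.

    past_kv_shape: (batch_size, kv_num_heads, past_sequence_length, head_size)
    total_sequence_length: value of the constant total_sequence_length input.
    """
    batch_size, kv_num_heads, past_sequence_length, head_size = past_kv_shape
    # CPU EP behavior: the present KV cache spans the longer of the two lengths.
    present_sequence_length = max(past_sequence_length, total_sequence_length)
    return (batch_size, kv_num_heads, present_sequence_length, head_size)


# Example: static KV cache of length 1024, constant total length 128.
# The inferred present shape keeps the cache length, which matches the WebNN
# requirement that present_sequence_length == past_sequence_length.
print(infer_gqa_present_kv_shape((1, 8, 1024, 64), 128))  # (1, 8, 1024, 64)
```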