onnxruntime
0a478c0d - [Shape Inference] Fix GQA shape inference for present outputs (#27250)

36 days ago
### Description

When using a pre-allocated KV cache with `freeDimensionOverrides`, shape inference for the `present_key` and `present_value` outputs failed silently. Downstream graph operations then received tensors with unknown dynamic shapes, triggering unexpected fallbacks in execution providers such as WebNN (which currently does not support dynamic shapes).

### Motivation and Context

**Root cause**: In `BaseGroupQueryAttentionTypeAndShapeInference()`, the shape inference logic for `use_max_past_present_buffer == -1` only propagated shapes when both of the following held:

1. `total_sequence_length_value` was a concrete value (> 0)
2. `past_dims[2]` had a concrete dimension value

When either condition failed (e.g., when `freeDimensionOverrides` produces a dynamic `past_sequence_length`), the present output shapes were left uninitialized. Additionally, when `past_key`/`past_value` were not provided (prefill/first-token mode), no shape inference was performed for the present outputs at all.

**Fix**:

1. For `use_max_past_present_buffer == -1`:
   - Always construct and propagate `present_shape`
   - Compute `present_sequence_length = max(past_sequence_length, total_sequence_length)` when both values are concrete
   - Fall back to copying `past_key`'s sequence dimension when the exact value cannot be computed
2. Add a new else-if branch to handle prefill mode (no `past_key`/`past_value` input):
   - Infer `head_size` from the query shape and the `num_heads`/`kv_num_heads` attributes
   - Handle both separate Q/K/V and packed QKV input formats
   - Construct the present shape from the query dims, `kv_num_heads`, and `total_sequence_length` or `kv_sequence_length`