onnxruntime
1f257837 - Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA (#28358)

Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA (#28358)

## Summary

Skip the Memory-Efficient Attention (MEA) path in the CUDA Attention operator when `head_size != v_head_size` in GQA mode, falling back to Unfused Attention instead.

## Problem

The Memory-Efficient Attention (MEA) path crashes with `cudaErrorMisalignedAddress` when:

- GQA mode (`q_num_heads != kv_num_heads`)
- `head_size != v_head_size` (e.g., Q.head_dim=256, K.head_dim=512)
- `seq_len >= 4` (Flash Attention is not eligible due to the attention mask)

This happens because MEA's `LaunchUngroup` requires equal head sizes, but the dispatch logic only checked this constraint for the past_key case (line 1380), not the general GQA case.

## Fix

Skip MEA for GQA when the head sizes differ; the Unfused Attention fallback handles this case correctly. (A sketch of the resulting dispatch predicate appears at the end of this message.)

## Affected Models

Gemma 4 was not affected: the graph that triggered the crash was incorrect to begin with. The fix is still worth having, since it makes the dispatch logic more robust.

~~**Gemma4** (google/gemma-4-e2b-it) with KV sharing:~~

- Layers 15-34 borrow K,V from source layers
- Q projection: 1536 → 2048 (8 heads × 256)
- K/V from source: [batch, 1, seq, 512]
- `head_size = 256`, `v_head_size = 512`

## Testing

Minimal repro (from #28357; expanded into a runnable script at the end of this message):

```python
# Attention(Q=[1,S,2048], K=[1,S,512], V=[1,S,512], q_num_heads=8, kv_num_heads=1)
# Before fix: seq=4+ crashes with misaligned address
# After fix: all seq lengths work
```

Full Gemma4 decoder (35 layers, 15 GQA + 20 standard Attention):

- Prefill seq=32: ✅
- Decode seq=1: ✅

Fixes #28357

Signed-off-by: Justin Chu <justinchu@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
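The dispatch change boils down to one extra condition in the kernel-selection logic. Below is a minimal sketch of that rule; the function name `can_use_memory_efficient_attention` and its parameters are hypothetical, chosen for illustration — the actual check lives in the CUDA execution provider's C++ attention dispatch code.

```python
# Sketch of the dispatch rule, NOT onnxruntime's actual code. Function and
# parameter names are hypothetical; the real check is in the CUDA execution
# provider's C++ attention dispatch logic.
def can_use_memory_efficient_attention(q_num_heads: int, kv_num_heads: int,
                                       head_size: int, v_head_size: int) -> bool:
    is_gqa = q_num_heads != kv_num_heads
    if is_gqa and head_size != v_head_size:
        # MEA's LaunchUngroup requires equal head sizes; under GQA with
        # mismatched sizes it crashes with cudaErrorMisalignedAddress,
        # so fall back to the Unfused Attention path instead.
        return False
    return True

# Repro configuration from above: GQA (8 vs 1 heads), head_size 256, v_head_size 512.
assert not can_use_memory_efficient_attention(8, 1, 256, 512)
```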
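Expanding the repro comment into a complete script: this is a sketch under the assumption that the model uses the ONNX opset-23 `Attention` operator with 3D inputs and runs on the `CUDAExecutionProvider`. The shapes and the `q_num_heads`/`kv_num_heads` attributes come from the repro comment above; the graph scaffolding, tensor names, and dynamic output dimension are illustrative, not taken from the linked issue.

```python
# Sketch of the minimal repro, assuming the ONNX opset-23 Attention operator
# with 3D inputs and the onnxruntime CUDA execution provider. Shapes mirror
# the repro comment; graph and tensor names are illustrative.
import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

S = 4  # before the fix, seq_len >= 4 crashed with cudaErrorMisalignedAddress

node = helper.make_node(
    "Attention", ["Q", "K", "V"], ["Y"],
    q_num_heads=8,   # head_size  = 2048 / 8 = 256
    kv_num_heads=1,  # v_head_size = 512 / 1 = 512 -> head_size != v_head_size
)
graph = helper.make_graph(
    [node], "mea_repro",
    [
        helper.make_tensor_value_info("Q", TensorProto.FLOAT, [1, S, 2048]),
        helper.make_tensor_value_info("K", TensorProto.FLOAT, [1, S, 512]),
        helper.make_tensor_value_info("V", TensorProto.FLOAT, [1, S, 512]),
    ],
    # Output shape left dynamic; per token the kernel produces
    # q_num_heads * v_head_size features under these assumptions.
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, S, None])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])

sess = ort.InferenceSession(
    model.SerializeToString(), providers=["CUDAExecutionProvider"]
)
feeds = {
    "Q": np.random.randn(1, S, 2048).astype(np.float32),
    "K": np.random.randn(1, S, 512).astype(np.float32),
    "V": np.random.randn(1, S, 512).astype(np.float32),
}
# Per the PR's testing notes, this completes for all sequence lengths after
# the fix; before it, seq_len >= 4 hit the misaligned-address crash.
print(sess.run(None, feeds)[0].shape)
```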