onnxruntime
d1abad00 - [webgpu] Propagate rotary_interleaved parameter to GQA shader (#26758)

Commit
14 days ago
[webgpu] Propagate rotary_interleaved parameter to GQA shader (#26758) ### Description This PR fixes the last tests that were failing in https://github.com/microsoft/onnxruntime/pull/26715#issuecomment-3626039240, where rotary_interleaved=1 in GQA kernel. The root cause was that the `rotary_interleaved` parameter was not being propagated correctly, meaning it always defaulted to 0 in `FusedQKRotaryEmbeddingProgram`. ``` Testing on Providers: ['CPUExecutionProvider', 'WebGpuExecutionProvider'] ================================================================================= Prefill_ColdStart | In:3 Past:0 Total:3 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08) Prefill_ColdStart | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) Decode_Early | In:1 Past:16 Total:17 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08) Decode_Deep | In:1 Past:64 Total:65 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) Speculative_Dec | In:4 Past:20 Total:24 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07) Batch_Prefill | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) Batch_Decode | In:1 Past:32 Total:33 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) GQA_Prefill | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) GQA_Decode | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) GQA_Batch_Dec | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07) MQA_Prefill | In:32 Past:0 Total:32 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07) MQA_Decode | In:1 Past:32 Total:33 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) LgBatch_MHA | In:1 Past:16 Total:17 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) LgBatch_GQA | In:1 Past:16 Total:17 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) Odd_SeqLen | In:7 Past:13 Total:20 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07) Odd_Heads | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) HighHeads_MHA | In:1 Past:32 Total:33 H:32 KV:32 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07) HighHeads_GQA | In:1 Past:32 Total:33 H:32 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07) HighHeads_MQA | In:1 Past:32 Total:33 H:32 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) LgCtx_Prefill | In:128 Past:0 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 4.17e-07) LgCtx_Decode | In:1 Past:127 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.98e-07) TinyHead_MHA | In:4 Past:4 Total:8 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) TinyHead_GQA | In:4 Past:4 Total:8 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.49e-07) LgHead_MHA | In:2 Past:2 Total:4 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) LgHead_GQA | In:2 Past:2 Total:4 H:4 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07) Ratio_5_1 | In:1 Past:10 Total:11 H:5 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) Ratio_6_2 | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) Ratio_6_3 | In:1 Past:10 Total:11 H:6 KV:3 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08) Ratio_12_4 | In:1 Past:10 Total:11 H:12 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) Zero_Past | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00) Single_Token_Prefill | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00) Rotary_Cache_Test | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Rotary | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Window_Small | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Window_Large | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07) Window_Decode | In:1 Past:20 Total:21 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Softcap_Enabled | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 5.98e-05) Scale_0.5 | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07) Rotary_Interleaved | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Rotary | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Rotary_Half | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Rotary Interleaved 2 | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07) Rotary_Window | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07) 🎉 ALL SCENARIOS PASSED ACROSS ALL PROVIDERS. ``` ### Motivation and Context cc @qjia7 @guschmue
Author
Parents
Loading