onnxruntime
3c94f1cc - [WebGPU] Fix MHA to ignore past key/value when no present outputs requested (#28027)

### Description

When MultiHeadAttention has only 1 output (no present_key/present_value outputs), past key/value inputs should be completely ignored, matching CPU EP semantics. The WebGPU EP was passing pastKey/pastValue TensorViews to shader creation functions even when outputCount <= 1, which affected shader cache keys and allowed past data to leak into the attention computation.

This caused the test "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue" to fail with output [17,18,19,20] (pastValue data) instead of the expected [9,10,11,12] (V data). The failing output matches exactly what happens when past IS used: Q·pastKey = 75 dominates Q·K = 35, so softmax gives ~100% of the weight to pastValue.

### Fix

In `applyAttention()`, introduce `effectivePastKey`/`effectivePastValue`, which are set to `undefined` when `outputCount <= 1`. All downstream usage (shader creation, input arrays) uses these effective values instead of the raw parameters. This ensures:

- Shader cache keys correctly reflect the "no past" configuration
- Past tensors are never passed to any shader creation function
- Behavior matches the CPU EP (which ignores past when the present outputs are null)
- GQA is unaffected (it always has outputCount >= 3)
- Vanilla Attention is unaffected (it always passes undefined for past)
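The gating described above, and the softmax arithmetic behind the failing test, can be sketched as follows. This is a minimal illustration, not the actual EP code: the names `effectivePastKey`/`effectivePastValue` and `outputCount` come from the commit message, while `effectivePast`, `softmax`, and the simplified `TensorView` type are hypothetical.

```typescript
// Simplified stand-in for the EP's TensorView type.
type TensorView = number[];

// Softmax over raw attention scores (numerically stabilized).
function softmax(scores: number[]): number[] {
  const m = Math.max(...scores);
  const exps = scores.map((s) => Math.exp(s - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// With scores Q·pastKey = 75 and Q·K = 35, softmax puts essentially all
// of the weight on the past position, so the leaked pastValue data
// [17,18,19,20] shows up in the output instead of the V data [9,10,11,12].
const weights = softmax([75, 35]); // weights[0] ~ 1, weights[1] ~ 0

// The fix: only forward past tensors downstream (shader creation, input
// arrays) when present outputs are actually requested (outputCount > 1).
function effectivePast(
  outputCount: number,
  pastKey?: TensorView,
  pastValue?: TensorView,
): { effectivePastKey?: TensorView; effectivePastValue?: TensorView } {
  const usePast = outputCount > 1;
  return {
    effectivePastKey: usePast ? pastKey : undefined,
    effectivePastValue: usePast ? pastValue : undefined,
  };
}
```

Because GQA always requests present outputs (outputCount >= 3) and vanilla Attention always passes `undefined` for past, this gate only changes behavior for the MHA single-output case.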