onnxruntime
aa416f5c - webgpu: Generalize FlashAttention prefill shared-memory path (#28520)

Commit
45 days ago
webgpu: Generalize FlashAttention prefill shared-memory path (#28520) ## Summary - Remove the `Subgroups` feature requirement from `CanApplyFlashAttention`, enabling flash attention on devices without subgroup support - Generalize the Apple-specific shared-memory prefill path into a `use_shm_path` flag that activates for Apple, NVIDIA, or any device lacking subgroups - Replace `is_apple` shader parameter with `use_shm_path` throughout the WGSL template ## Motivation Two issues exist on the current main branch: 1. **NVIDIA prefill produces incorrect results (regression from #28511):** PR #28511 increased `max_k_step` to 32 for NVIDIA in C++, but the shader's subgroup-based path only has `qk_1..qk_4` (16 hardcoded key indices). When `sg_size=32` (e.g. RTX 5080), the loop steps by 32 but only computes QK for keys 0-15, silently skipping keys 16-31. This produces incorrect attention output for models like phi4. 2. **Flash attention prefill unavailable without Subgroups:** `CanApplyFlashAttention` gates on `context.HasFeature(wgpu::FeatureName::Subgroups)`, forcing devices without subgroup support to fall back to the slower split-reduce 2-kernel path for prefill, even though the Apple shared-memory path in the shader is fully subgroup-free. This PR fixes both issues by routing Apple, NVIDIA, and no-subgroup devices through the loop-based shared-memory path (`use_shm_path`), which naturally handles any `max_k_step` value via `array<q_element_t, max_k_step>` and loop iteration — no hardcoded key count. ## Test plan - [x] Built ORT with WebGPU EP on Windows (Release, VS 2022) - [x] Deployed and ran phi4-graph-prune model: output verified correct ("1+1 equals 2.") - [x] Lint check passed (`lintrunner -a`)
Author
Parents
Loading