onnxruntime
3b007a68 - webgpu: Support QKV bias in FlashAttention for MultiHeadAttention (#28380)

Commit
31 days ago
webgpu: Support QKV bias in FlashAttention for MultiHeadAttention (#28380) ## Summary - Remove the `bias == nullptr` requirement from `CanApplyFlashAttention`, enabling FlashAttention for MultiHeadAttention nodes with QKV bias (e.g., whisper decoder). - Apply `TransferBSDToBNSH` to add bias and transpose Q/K/V to BNSH format before calling FlashAttention. - Handle cross-attention (only Q needs bias+transpose, K/V already BNSH from encoder) and self-attention (all Q/K/V need bias+transpose) separately. ## Motivation Whisper decoder's MultiHeadAttention nodes all have QKV bias, which previously forced them into the slower unfused attention path. Enabling FlashAttention for these nodes yields ~45% speedup on whisper-tiny-int4 (~92 → ~134 tokens/s). ## Test plan - [x] Existing MHA unit tests with bias data now exercise the FlashAttention path on WebGPU with Subgroups support - [x] whisper-tiny-int4 end-to-end: correct transcription at ~134 tps (vs ~92 tps baseline) - [x] clang-format passes - [x] D3D12 build succeeds
Author
Parents
Loading