onnxruntime
549d7415 - [webgpu] Optimize Attention by enhancing flash attention support (#26715)

Commit
9 days ago
[webgpu] Optimize Attention by enhancing flash attention support (#26715) This pull request improves the WebGPU BERT attention implementation by enhancing FlashAttention support, generalizing tensor layout handling, and increasing batch size flexibility. The changes focus on supporting both BSNH and BNSH tensor layouts, enabling FlashAttention for multi-batch scenarios, and ensuring correct broadcasting and dispatch sizing for attention bias and batch dimensions. Key improvements include: **FlashAttention Support & Generalization:** * Added support for both BSNH and BNSH tensor layouts by introducing the `q_BNSH` parameter and updating shader code, program classes, and kernel logic to handle either layout correctly. This includes changes in the WGSL template and C++ logic for offset calculations and program instantiation. [[1]](diffhunk://#diff-de9fb56a92586a62185eae0a2e0153f12960bc73dab990e616185236e115885fR7) [[2]](diffhunk://#diff-de9fb56a92586a62185eae0a2e0153f12960bc73dab990e616185236e115885fL45-R97) [[3]](diffhunk://#diff-de9fb56a92586a62185eae0a2e0153f12960bc73dab990e616185236e115885fL86-R122) [[4]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8R445) [[5]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8R454) [[6]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9R76) [[7]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9R86) [[8]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9R110) * Updated the `CanApplyFlashAttention` and `ApplyFlashAttention` logic to allow multi-batch operation by removing the restriction to batch size 1 and ensuring present key/value tensors are always created for FlashAttention. [[1]](diffhunk://#diff-1ed746fa440247995dabd97ad1f318a548fc385cde70b9ea2d4a410219f91629R740-R752) [[2]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8L501-L506) [[3]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9L177-R185) **Batch & Bias Handling:** * Modified dispatch group size calculations and uniform variables throughout the FlashAttention pipeline to properly account for batch size, ensuring correct parallelization for multi-batch scenarios. [[1]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8R260-R273) [[2]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8L272-R285) [[3]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8L320-R333) [[4]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8L366-R379) [[5]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8L454-R490) [[6]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9R95-R100) [[7]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9L123-R131) * Added logic to extract and pass attention bias dimensions as uniforms for correct broadcasting in both the compute and shader code. [[1]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8R260-R273) [[2]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8L272-R285) [[3]](diffhunk://#diff-c21dffe27e10565d78827773edf856be89f28b4dfefe1a79d18e083c0b18b0e8L454-R490) [[4]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9R95-R100) [[5]](diffhunk://#diff-27882bdbb4d2adc903ff91fbb7b09feb61c53d7a5d86d9336e294d631b7f59e9L123-R131) **Other Enhancements:** * Improved handling of QKV format detection and generalized code to support more format variants in `CopyKVCache`. * Updated includes and dependencies to ensure all necessary headers for FlashAttention are present. These changes collectively make the WebGPU BERT attention implementation more robust, flexible, and performant across different tensor layouts and batch sizes. phi-4-mm-vision.onnx Before Kernel | Time (ms) | Percentage (%) -- | -- | -- Attention\|AttentionProbs | 159.66 | 11.14 Attention\|VxAttentionScore | 122.56 | 8.55 Attention\|InPlaceSoftmax | 51.83 | 3.62 After Kernel | Time (ms) | Percentage (%) -- | -- | -- Attention\|FlashAttention | 60.23 | 5.38
Author
Parents
Loading