onnxruntime
bf71213a - Support boolean attention mask in Attention(23) CUDA - MHA case (#27428)

Replaces and relands https://github.com/microsoft/onnxruntime/pull/27129

Comparison of this PR's approach (pre-converting the mask) with handling the bool mask inline in the softmax kernels:

## Tradeoffs

| Category | Pre-conversion (this PR) | Inline in softmax |
| :--- | :--- | :--- |
| **Memory** | Extra buffer ($num\_elements \times sizeof(T)$) | None (reads the 1-byte bool directly) |
| **Kernel launches** | +1 simple elementwise kernel | No extra launches |
| **Code complexity** | 3 files, ~40 lines added | 6+ kernel templates, macros, dispatch logic, data structs |
| **Risk** | Low (softmax path untouched) | High (modifies battle-tested softmax kernels used by the MHA and GQA contrib ops) |
| **Perf impact** | Negligible (the mask is small relative to QKV, and the conversion is memory-bound and fast) | Slightly better theoretical bandwidth |
| **Maintainability** | Clean separation of concerns | Adds a template dimension across all softmax variants |

---

This pull request adds boolean attention mask (bool mask) support to the ONNX Runtime CUDA Attention operator in the Multi-Head Attention (MHA) path by converting the mask to an additive attention bias on the GPU, and it extends test coverage to verify correctness and parity with the CPU implementation. The main changes are a CUDA kernel for mask conversion, updated operator logic to handle bool masks, clarified broadcasting rules, and comprehensive unit tests.

**CUDA Attention Operator Improvements:**

* Implemented a CUDA kernel (`LaunchConvertBoolMaskToAttentionBias`) that converts a boolean attention mask to an additive bias (True → 0.0, False → `mask_filter_value`) for the MHA path, keeping the conversion entirely on the GPU; a hedged sketch of this pattern appears at the end of this description. [[1]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49R148-R187) [[2]](diffhunk://#diff-8aa9a15a92d7dc138346dce5de055911895d940ba2183b4ba45bd95ac0e5bfc9R55-R66)
* Updated `attention.cc` to use this kernel, handle bool masks correctly in the MHA path, and clarify the broadcasting logic and mask-shape interpretation for both GQA and MHA. [[1]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffR6) [[2]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffR380-R383) [[3]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL514-L522) [[4]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL549-R557) [[5]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffR595-R616)

**Testing and Documentation Enhancements:**

* Added new test cases and a dedicated test class that validate boolean mask handling in the MHA path against the CPU implementation for 2D, 3D, and 4D mask shapes. [[1]](diffhunk://#diff-801fbbcf2537e8e13a0202e6a0f7e88c56ab5aa72d17d949a5556355694b2b2dR563-R725) [[2]](diffhunk://#diff-801fbbcf2537e8e13a0202e6a0f7e88c56ab5aa72d17d949a5556355694b2b2dR893-R922)
* Improved comments and documentation in both code and tests to clarify ONNX broadcasting rules and the mask shapes each attention path expects. [[1]](diffhunk://#diff-4ed1461afda0d3804a61ba95a64b2a84d0c1395f9c887d1a3fdfed914ade22c1L208-R221) [[2]](diffhunk://#diff-801fbbcf2537e8e13a0202e6a0f7e88c56ab5aa72d17d949a5556355694b2b2dR35)

**Test Coverage and Reliability:**

* Enabled CUDA execution for boolean-mask tests that previously ran only on CPU, and adjusted the test logic to handle edge cases such as all-false masks. [[1]](diffhunk://#diff-3ff6dfa2ce407ae0073009174c37d1756509e8bbc434dee7c44cd55a996bb777L477-R480) [[2]](diffhunk://#diff-3ff6dfa2ce407ae0073009174c37d1756509e8bbc434dee7c44cd55a996bb777L620-R623)

These changes make the CUDA Attention operator more robust and feature-complete, aligning its behavior with the CPU implementation and the ONNX specification.
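To make the pre-conversion approach concrete, here is a minimal, hypothetical sketch of such an elementwise kernel. It is not the actual onnxruntime implementation: the names `ConvertBoolMaskToBiasKernel` and `LaunchConvertBoolMaskToBias`, the float-only element type, and the flat (non-broadcast) indexing are assumptions for illustration, whereas the real `LaunchConvertBoolMaskToAttentionBias` is templated over the element type and honors the broadcasting rules described above.

```cuda
#include <cuda_runtime.h>

// True -> 0.0f (position attends normally); False -> mask_filter_value
// (typically a large negative number such as -10000.0f, so softmax drives
// the attention weight for that position toward zero).
__global__ void ConvertBoolMaskToBiasKernel(const bool* mask, float* bias,
                                            int num_elements,
                                            float mask_filter_value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_elements) {
    bias[i] = mask[i] ? 0.0f : mask_filter_value;
  }
}

// Host-side launcher: one extra elementwise launch and one extra buffer of
// num_elements * sizeof(float), matching the "pre-conversion" column of the
// tradeoff table above.
cudaError_t LaunchConvertBoolMaskToBias(cudaStream_t stream, const bool* mask,
                                        float* bias, int num_elements,
                                        float mask_filter_value) {
  if (num_elements <= 0) return cudaSuccess;
  constexpr int kThreadsPerBlock = 256;
  int blocks = (num_elements + kThreadsPerBlock - 1) / kThreadsPerBlock;
  ConvertBoolMaskToBiasKernel<<<blocks, kThreadsPerBlock, 0, stream>>>(
      mask, bias, num_elements, mask_filter_value);
  return cudaGetLastError();
}
```

The resulting buffer is then consumed like any other additive attention bias: it is added to the scaled $QK^\top$ scores before softmax, so False positions get a large negative score and a near-zero attention weight. That is the "clean separation of concerns" row in the table: the battle-tested softmax kernels never see a boolean mask at all.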