onnxruntime
a3e477e0 - Attention(23) CUDA (#26466)

Attention(23) CUDA (#26466)

This pull request introduces significant improvements and expanded support for multi-head attention kernels in ONNX Runtime, focusing on support for both 3D (`BSNH`) and 4D (`BNSH`) QKV input formats. The changes improve the flexibility, correctness, and maintainability of attention operations across the CPU and CUDA implementations.

### Expanded QKV Input Format Support

* Added support for the 4D QKV input format (`Q_K_V_BNSH`) in CUDA attention kernels, including proper handling of the cases with and without past/present state, and enforcement that bias is not supported for this format. This includes logic to avoid unnecessary transposes and to write outputs directly when possible (see the layout sketch below). [[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R264-R265) [[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R343-R354) [[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R388-L388) [[4]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R426-R435) [[5]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716) [[6]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R747-R748) [[7]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
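To make the layout distinction concrete, here is a minimal C++ sketch of the index arithmetic behind the two formats; the helper names and the `main` driver are illustrative, not ONNX Runtime code. It shows why each head's `S x H` slice is contiguous in `BNSH` but strided in `BSNH`, which is what lets the 4D path skip the transpose.

```cpp
#include <cstddef>
#include <cstdio>

// Index math for the two QKV layouts discussed above, where
// B = batch, S = sequence length, N = number of heads, H = head size.
// These helpers are illustrative, not ONNX Runtime code.

// 3D "BSNH": element (b, s, n, h) lives at ((b*S + s)*N + n)*H + h.
// Heads are interleaved within each token, so a per-head GEMM needs
// a transpose to BNSH first.
std::size_t IndexBSNH(std::size_t b, std::size_t s, std::size_t n, std::size_t h,
                      std::size_t S, std::size_t N, std::size_t H) {
  return ((b * S + s) * N + n) * H + h;
}

// 4D "BNSH": element (b, n, s, h) lives at ((b*N + n)*S + s)*H + h.
// Each head's S x H slice is contiguous, so batched attention GEMMs
// can read it (and write the output) directly, with no transpose.
std::size_t IndexBNSH(std::size_t b, std::size_t s, std::size_t n, std::size_t h,
                      std::size_t S, std::size_t N, std::size_t H) {
  return ((b * N + n) * S + s) * H + h;
}

int main() {
  const std::size_t S = 4, N = 3, H = 8;
  // The same logical element (b=1, s=2, n=1, h=5) sits at different
  // physical offsets under the two layouts.
  std::printf("BSNH offset: %zu\n", IndexBSNH(1, 2, 1, 5, S, N, H));
  std::printf("BNSH offset: %zu\n", IndexBNSH(1, 2, 1, 5, S, N, H));
  return 0;
}
```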
### Kernel and Operator Documentation Updates

* Updated `OperatorKernels.md` to document the new `Attention` operator inputs and outputs for both the 3D and 4D formats, specifying the supported tensor types for each input.

### Correctness and Consistency Fixes

* Fixed the computation of causal attention indices in the CUDA softmax kernels by clarifying and correcting the offset calculation used for causal masking (a sketch of the intended rule appears at the end of this description). [[1]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL168-R168) [[2]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL244-R244) [[3]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL336-R336) [[4]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL442-R442)
* Updated the workspace allocation logic for QKV preparation so that the correct workspace size is used for the new formats.

### Attention Parameter and Helper Refactoring

* Added an `is_output_bnsh` field to `AttentionParameters` to indicate the output format, and updated the logic that decides output placement and transposition to use it. [[1]](diffhunk://#diff-e742290164e1e1fa0152840db2a1b83354e153153df19a2762b58655e49b7f9bR37) [[2]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
* Refactored the CPU attention implementation to use the new `attention_helper` namespace for the output-mode enums and output shape computation, improving code clarity and maintainability. [[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R5) [[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L118-R125) [[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L143-R149)

### Minor Cleanups

* Removed outdated asserts and improved the debug output strings of the QKV preparation functions to clarify format and state handling. [[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L254) [[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L363) [[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)

These changes collectively improve the flexibility, correctness, and maintainability of the attention kernel implementations in ONNX Runtime, especially for advanced transformer models and large language model workloads.

**NOT supported in this PR**

- Boolean mask
- GQA
- Softcap
- Softmax precision
- `qk_output_mode` other than -1 and 0
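For reference, the causal masking rule that the softmax offset fix concerns can be sketched as follows. This is a hypothetical illustration of the standard semantics with a past KV cache, not the kernel's actual code; the PR does not spell out the exact offset bug.

```cpp
#include <cstdio>

// Illustrative only: the usual causal rule with a past KV cache.
// A query at 0-based position q within the new chunk may attend to
// key positions [0, past_len + q], i.e. past_len + q + 1 keys; all
// later keys are masked out before the softmax.
int CausalValidKeyCount(int q, int past_len) {
  return past_len + q + 1;
}

int main() {
  const int past_len = 3;     // tokens already in the KV cache
  const int new_seq_len = 4;  // tokens processed in this pass
  const int total_len = past_len + new_seq_len;
  for (int q = 0; q < new_seq_len; ++q) {
    const int valid = CausalValidKeyCount(q, past_len);
    std::printf("query %d: keys [0, %d) visible, %d masked\n",
                q, valid, total_len - valid);
  }
  return 0;
}
```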