onnxruntime
a1b634cf - Support nonpad kv seqlen within opset 24 Attention (CPU) (#27384)

Support nonpad kv seqlen within opset 24 Attention (CPU) (#27384)

- Support the opset 24 Attention `nonpad_kv_seqlen` input and add tests
- Refactor the GEMM logic to improve performance and maintainability
- Refactor the fp16 fallback GEMM branch: Upcast -> GemmEx -> Downcast

---

This pull request refactors and enhances the ONNX Runtime CPU attention operator, focusing on improved GEMM (matrix multiplication) handling for both float32 and MLFloat16 types, and adds support for the new `nonpad_kv_seqlen` input (Opset 24+) to enable more flexible masking. The changes simplify code paths, optimize performance (especially for MLFloat16), and improve maintainability.

Key changes include:

**1. Unified and Optimized GEMM Handling**

- Introduces a new templated `AttentionGemm` function that dispatches GEMM operations for both float and MLFloat16, handling hardware capabilities and providing efficient fallbacks (including upcasting/downcasting for MLFloat16 when necessary). This replaces multiple scattered and duplicated GEMM code paths throughout the attention implementation. [[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R113-R197) [[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L357-R462) [[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L590-R646)

**2. Support for nonpad_kv_seqlen (Opset 24+)**

- Adds handling for the optional `nonpad_kv_seqlen` input: validates its shape and values, ensures it is not used together with past key/value, and applies per-batch masking to attention scores based on the valid key/value sequence length.
[[1]](diffhunk://#diff-b41e261f1e5a9ffdc334663560199cb827103c2a8af9ef96d92d37e7d0fd3312R18) [[2]](diffhunk://#diff-b41e261f1e5a9ffdc334663560199cb827103c2a8af9ef96d92d37e7d0fd3312R106-R128) [[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R230) [[4]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R246) [[5]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R473-R481)

**3. Code Clean-up and Maintainability**

- Removes duplicated and complex branching logic for GEMM and MatMul operations, consolidating it into the new `AttentionGemm` helper for both Q*K and QK*V multiplications. This reduces code complexity and the risk of subtle bugs. [[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L357-R462) [[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L590-R646)

**4. Minor Include and Utility Updates**

- Adds the `<algorithm>` and `<vector>` includes required by the new logic.

These changes collectively improve the performance, clarity, and extensibility of the attention implementation, particularly for models using MLFloat16 and for newer ONNX opsets that require more flexible masking.