Support nonpad kv seqlen within opset 24 Attention (CPU) (#27384)
- Support the `nonpad_kv_seqlen` input of opset 24 Attention and add tests
- Refactor the GEMM logic to improve performance and maintainability
- Refactor the fp16 fallback GEMM branch: upcast -> GemmEx -> downcast
---
This pull request refactors and enhances the ONNX Runtime CPU attention
operator with a focus on improved GEMM (matrix multiplication) handling
for both float32 and MLFloat16 types, and adds support for the new
`nonpad_kv_seqlen` input (Opset 24+) to enable more flexible masking.
The changes simplify code paths, optimize performance (especially for
MLFloat16), and improve maintainability.
Key changes include:
**1. Unified and Optimized GEMM Handling**
- Introduces a new templated `AttentionGemm` function that dispatches
GEMM operations for both float and MLFloat16, handling hardware
capabilities and providing efficient fallbacks (including
upcasting/downcasting for MLFloat16 when necessary). This replaces
multiple scattered and duplicated GEMM code paths throughout the
attention implementation.
[[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R113-R197)
[[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L357-R462)
[[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L590-R646)
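As a rough illustration of this dispatch (all names below are hypothetical stand-ins, not the actual ONNX Runtime API), a single templated entry point can branch on the element type at compile time; the fp16 path, when no native half-precision kernel is available, takes the upcast -> float GEMM -> downcast fallback described above. A naive triple loop stands in for the optimized GemmEx/MlasGemm kernel, and `uint16_t` models the fp16 storage type:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

// Simplified fp16 bit-layout helpers (normal, nonzero values only) --
// stand-ins for the real MLFloat16 conversion utilities.
inline float HalfToFloat(uint16_t h) {
  uint32_t bits = ((h & 0x8000u) << 16)                 // sign
                | ((((h >> 10) & 0x1Fu) + 112u) << 23)  // rebiased exponent
                | ((h & 0x3FFu) << 13);                 // mantissa
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

inline uint16_t FloatToHalf(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(((bits >> 16) & 0x8000u) |
                               ((((bits >> 23) & 0xFFu) - 112u) << 10) |
                               ((bits >> 13) & 0x3FFu));
}

// Naive row-major float GEMM: C[M,N] = A[M,K] * B[K,N], standing in
// for the optimized float kernel.
inline void FloatGemm(const float* A, const float* B, float* C,
                      size_t M, size_t K, size_t N) {
  for (size_t m = 0; m < M; ++m)
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = acc;
    }
}

// Single templated entry point: float goes straight to the kernel;
// fp16 (modeled as uint16_t) upcasts, runs the float kernel, then
// downcasts the result.
template <typename T>
void AttentionGemm(const T* A, const T* B, T* C,
                   size_t M, size_t K, size_t N) {
  if constexpr (std::is_same_v<T, float>) {
    FloatGemm(A, B, C, M, K, N);
  } else {
    std::vector<float> a(M * K), b(K * N), c(M * N);
    for (size_t i = 0; i < a.size(); ++i) a[i] = HalfToFloat(A[i]);
    for (size_t i = 0; i < b.size(); ++i) b[i] = HalfToFloat(B[i]);
    FloatGemm(a.data(), b.data(), c.data(), M, K, N);
    for (size_t i = 0; i < c.size(); ++i) C[i] = FloatToHalf(c[i]);
  }
}
```

Both the Q*K and QK*V multiplications can then call the same helper, which is the consolidation this change makes.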
**2. Support for nonpad_kv_seqlen (Opset 24+)**
- Adds handling for the optional `nonpad_kv_seqlen` input: validates its
shape and values, ensures it is not used with past key/value, and
applies per-batch masking to attention scores based on the valid
key/value sequence length.
[[1]](diffhunk://#diff-b41e261f1e5a9ffdc334663560199cb827103c2a8af9ef96d92d37e7d0fd3312R18)
[[2]](diffhunk://#diff-b41e261f1e5a9ffdc334663560199cb827103c2a8af9ef96d92d37e7d0fd3312R106-R128)
[[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R230)
[[4]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R246)
[[5]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R473-R481)
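A minimal sketch of that per-batch masking (hypothetical names, not the actual kernel): for each batch, score entries at key/value positions at or beyond the batch's valid length are set to negative infinity, so the subsequent softmax assigns them zero weight.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>

// scores: row-major [batch_size, q_seq_len, kv_seq_len] attention scores.
// nonpad_kv_seqlen: per-batch count of valid (non-padded) kv positions.
void MaskPaddedKvPositions(float* scores, const int32_t* nonpad_kv_seqlen,
                           size_t batch_size, size_t q_seq_len,
                           size_t kv_seq_len) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  for (size_t b = 0; b < batch_size; ++b) {
    const size_t valid = static_cast<size_t>(nonpad_kv_seqlen[b]);
    float* batch_scores = scores + b * q_seq_len * kv_seq_len;
    for (size_t q = 0; q < q_seq_len; ++q)
      for (size_t k = valid; k < kv_seq_len; ++k)  // padded tail only
        batch_scores[q * kv_seq_len + k] = neg_inf;
  }
}
```

This is also why the input is rejected when past key/value is present: a cached kv sequence would shift which positions count as padding.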
**3. Code Clean-up and Maintainability**
- Removes duplicated and complex branching logic for GEMM and MatMul
operations, consolidating them into the new `AttentionGemm` helper for
both Q*K and QK*V multiplications. This reduces code complexity and the
risk of subtle bugs.
[[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L357-R462)
[[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L590-R646)
**4. Minor Includes and Utility Updates**
- Adds the necessary includes for `<algorithm>` and `<vector>` to support
the new logic.
These changes collectively improve the performance, clarity, and
extensibility of the attention implementation, particularly for models
using MLFloat16 and for newer ONNX opsets that require more flexible
masking.