Support Attention(24) on CUDA, decoupled from contrib ops (#27542)
Fix #27485
Reference https://github.com/microsoft/onnxruntime/pull/27486
## Summary
Implements CUDA kernel support for the **ONNX Attention op (opset 23 and
24)** using a thin-dispatcher architecture that directly calls Flash
Attention, Memory Efficient Attention (cutlass FMHA), and Unfused
(GEMM+softmax+GEMM) kernels — bypassing the contrib MHA/GQA dispatch
layers entirely. This covers both Multi-Head Attention (MHA) and
Grouped-Query Attention (GQA) variants, with support for internal KV
cache (past/present), external KV cache (`nonpad_kv_seqlen`), boolean
and float attention masks, combined `nonpad_kv_seqlen + attn_mask`
composition, and 3D (BSNH) / 4D (BNSH) input formats.
The kernel respects `ORT_DISABLE_FLASH_ATTENTION` and
`ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION` environment variables,
consistent with contrib MHA/GQA behavior.
---
## Architecture: Thin-Dispatcher Design
The core kernel (`attention.cc`) implements a **3-path dispatch
cascade** with clear priority:
```
ComputeInternal
│
├─ Early reject: softcap, softmax_precision (not supported by any CUDA kernel)
│
├─ Flash Attention (fp16/bf16, Ampere+, head_size==v_head_size, no output_qk)
│ ├─ Path 1: nonpad_kv_seqlen (opset 24 external cache) → mha_fwd_kvcache
│ ├─ Path 2: past_key + past_value (internal cache decode) → mha_fwd_kvcache
│ │ Supports: bool mask → seqlens_k, no mask → fill past_seq_len
│ └─ Path 3: no past, no mask (prompt) → mha_fwd
│ Excluded: nonpad + mask combo (no bias param), float masks on prompt
│
├─ Memory Efficient (fp16/bf16, SM75+, no past_key, bias stride aligned)
│ ├─ Path 1: nonpad_kv_seqlen (±mask) → custom_right_padding + additive bias
│ ├─ Path 2: with mask (prompt) → standard MEA with additive bias
│ └─ Path 3: no mask (prompt) → standard MEA
│
└─ Unfused Fallback (all dtypes, all mask types, MHA only)
├─ nonpad_kv_seqlen (±mask) → attention_bias with additive composition
└─ Bridges to contrib MHA's QkvToContext (GEMM+softmax+GEMM)
```
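
The cascade above can be read as a priority-ordered selection function. The following Python sketch is a simplified model for following the diagram, not the code in `attention.cc`: `AttnConfig`, its field names, and `select_kernel` are hypothetical, and the eligibility conditions are transcribed from the diagram and the limitations table in this description (the `output_qk` exclusion for MEA is inferred from that table). Ampere is taken to mean compute capability 8.0 or higher.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnConfig:
    """Hypothetical summary of one Attention call (not a real ORT type)."""
    dtype: str                          # "fp16", "bf16", or "fp32"
    sm: int                             # compute capability, e.g. 80 for Ampere
    num_heads: int
    kv_num_heads: int
    head_size: int = 64
    v_head_size: int = 64
    has_past: bool = False              # past_key + past_value supplied
    has_nonpad_kv_seqlen: bool = False  # opset 24 external cache
    mask: Optional[str] = None          # None, "bool", or "float"
    output_qk: bool = False
    softcap: float = 0.0
    softmax_precision: int = 0
    bias_stride_aligned: bool = True
    disable_flash: bool = False         # models ORT_DISABLE_FLASH_ATTENTION
    disable_mea: bool = False           # models ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION


def select_kernel(c: AttnConfig) -> str:
    # Early reject: no CUDA kernel supports these attributes.
    if c.softcap != 0.0 or c.softmax_precision != 0:
        return "unsupported"

    half = c.dtype in ("fp16", "bf16")
    is_gqa = c.num_heads != c.kv_num_heads

    # 1) Flash Attention.
    flash_ok = (not c.disable_flash and half and c.sm >= 80
                and c.head_size == c.v_head_size and not c.output_qk)
    if flash_ok:
        if c.has_nonpad_kv_seqlen:
            if c.mask is None:              # nonpad + mask combo is excluded
                return "flash"
        elif c.has_past:
            if c.mask in (None, "bool"):    # float masks are rejected on decode
                return "flash"
        elif c.mask is None:                # prompt without a mask
            return "flash"

    # 2) Memory Efficient Attention (prompt only, no past_key).
    mea_ok = (not c.disable_mea and half and c.sm >= 75 and not c.has_past
              and c.bias_stride_aligned and not c.output_qk)
    if mea_ok:
        return "mea"

    # 3) Unfused fallback: all dtypes and mask types, but MHA only.
    return "unsupported" if is_gqa else "unfused"
```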
In the MEA path, GQA head expansion is handled by `LaunchUngroup`. Flash
handles GQA natively via the `kv_num_heads` parameter.
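
As a rough picture of what head expansion means, the Python sketch below replicates each KV head so that every query head gets its own copy. It assumes the usual GQA grouping in which consecutive query heads share one KV head; that grouping and the `ungroup_heads` name are assumptions for illustration, not a description of the `LaunchUngroup` kernel.

```python
def ungroup_heads(kv_heads, num_heads):
    """Expand a list of per-KV-head items to one item per query head."""
    kv_num_heads = len(kv_heads)
    if kv_num_heads == 0 or num_heads % kv_num_heads != 0:
        raise ValueError("num_heads must be a positive multiple of kv_num_heads")
    group = num_heads // kv_num_heads  # query heads sharing one KV head
    # Query head h reads KV head h // group (assumed grouping).
    return [kv_heads[h // group] for h in range(num_heads)]


# 8:2 head ratio, one of the ratios listed in the GQA tests.
expanded = ungroup_heads(["kv0", "kv1"], num_heads=8)
```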
---
## What DOESN'T Work ❌
| Combination | Reason | TODO? |
| :--- | :--- | :--- |
| **GQA + fp32** (any) | MEA's `LaunchUngroup` is fp16/bf16 only; unfused rejects GQA | ✅ TODO added |
| **GQA + float mask + decode** | Flash rejects float masks; MEA rejects decode; unfused rejects GQA | Known limitation |
| **Softcap on CUDA** | All 3 kernels reject | ✅ TODO added |
| **softmax_precision on CUDA** | All 3 kernels reject | ✅ TODO added |
| **output_qk** beyond `kNone`/`kQK` | Only unfused supports `output_qk`, and only the `kNone`/`kQK` modes | ✅ TODO added |

\* "All 3 kernels" refers to Flash, MEA, and Unfused.
### Additional Known Limitations
| Limitation | Details |
| :--- | :--- |
| **attn_mask shorter than total_seq** | Spec allows padding with -inf; validation enforces exact match. Tracked separately. |
| **No past_present_share_buffer** | Decode always copies past→present (memset + strided copy). Contrib GQA can share buffers for zero-copy. |
| **No XQA kernel** | Contrib GQA's specialized decode kernel requires shared buffers. ONNX Attention's internal-cache decode is ~15-30% slower for GQA workloads. This does not apply when using TensorScatter + `nonpad_kv_seqlen` (opset 24), which avoids the copy overhead entirely. |
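
To make the cost of the missing `past_present_share_buffer` concrete, the Python sketch below models the past→present copy for a single (batch, head) slice: a zero-filled buffer of the full length is allocated, the past rows are copied in, and the new rows are placed after them. The `build_present` name and the list-of-rows layout are hypothetical; the point is only that each decode step touches every past row, which a shared buffer would avoid.

```python
def build_present(past_rows, new_rows, head_size):
    """Model of decode for one (batch, head): copy past, then append new rows."""
    total = len(past_rows) + len(new_rows)
    # Analogue of the memset: start from a zeroed present buffer.
    present = [[0.0] * head_size for _ in range(total)]
    # Analogue of the strided copy: every past row is copied on every step.
    for i, row in enumerate(past_rows):
        present[i] = list(row)
    # New key/value rows land after the past rows.
    for j, row in enumerate(new_rows):
        present[len(past_rows) + j] = list(row)
    return present


past = [[1.0, 1.0], [2.0, 2.0]]
new = [[3.0, 3.0]]
present = build_present(past, new, head_size=2)
```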
---
## Bug Fixes
- **2D mask shape interpretation (MEA path):** Fixed incorrect
interpretation of 2D `attn_mask` as `[B, kv_seq]` — the ONNX spec
defines it as `[q_seq, total_seq]`. The old code could read
out-of-bounds when `batch_size > q_seq` (e.g., batched decode). Fix uses
MEA's native broadcast flags (`strideB=0, strideH=0`) instead of an
expansion kernel, which is both correct and more efficient.
- **4D BNSH present_key/value (prompt path):** Added D2D
`cudaMemcpyAsync` for the `!is_bsnh` case in both Flash and MEA prompt
paths. Previously, `present_key`/`present_value` outputs were silently
left uninitialized for 4D BNSH inputs. Decode path already worked via
`mha_fwd_kvcache`; unfused already worked via `QkvToContext`.
- **Cutlass FMHA bias alignment:** Added alignment checks for attention
bias strides in the FMHA launch template, mirroring checks in the kernel
to prevent runtime enforcement errors on misaligned strides.
- **SEGFAULT fix** (`group_query_attention_impl.cu`): Added missing
explicit template instantiations for `LaunchUngroup<__half>` and
`LaunchUngroup<__nv_bfloat16>`. The new `attention.cc` calls
`LaunchUngroup` cross-TU, but without explicit instantiations the symbol
was undefined at runtime → `dlopen` crash → all CUDA test binaries
failed.
- **`val >= 0` assert fix** (`attention_mask_impl.cu`): Fixed overly
strict `val > 0` asserts in Flash kernel and AttentionBias kernel —
`val=0` is valid for count semantics. Kept `val > 0` in GQA kernel with
comment explaining last-valid-index convention.
- **Negative seqlens_k from all-false mask** (`attention_mask_impl.cu`):
When the mask is all-false (`seq_len=0`) and the decode offset is negative,
`seq_len + seqlen_offset` produced a negative `seqlens_k` that was passed to
Flash. Fixed with a `max(0, seq_len + seqlen_offset)` clamp.
- **BFloat16 Flash gate fix:** Fixed `disable_flash_attention_`
constructor gate that used `!std::is_same<T, MLFloat16>::value`,
accidentally excluding BFloat16 from Flash. Changed to `sizeof(T) != 2`
to match contrib GQA/PackedMHA/PagedAttention pattern.
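
The 2D mask fix in the list above can be illustrated with plain offset arithmetic. In this Python sketch a `[q_seq, total_seq]` mask is stored flat; reading it with a zero batch stride and a zero head stride keeps every access in bounds, while treating the same buffer as `[B, kv_seq]` (batch stride equal to `total_seq`) runs past the end once `batch_size > q_seq`. The `bias_offset` helper and the `stride_q` name are assumptions for illustration; only `strideB=0, strideH=0` come from this description.

```python
def bias_offset(b, h, q, k, stride_b, stride_h, stride_q):
    """Flat offset of element (b, h, q, k) for the given strides."""
    return b * stride_b + h * stride_h + q * stride_q + k


batch, q_seq, total_seq = 3, 1, 4     # batched decode: batch_size > q_seq
neg_inf = float("-inf")
mask = [0.0, 0.0, neg_inf, neg_inf]   # flat [q_seq, total_seq] additive mask
last_k = total_seq - 1

# Broadcast read: the same [q_seq, total_seq] block serves every batch and head.
broadcast_offsets = [
    bias_offset(b, 0, 0, last_k, stride_b=0, stride_h=0, stride_q=total_seq)
    for b in range(batch)
]

# Misreading the buffer as [B, kv_seq]: the batch stride walks out of bounds.
misread_offsets = [
    bias_offset(b, 0, 0, last_k, stride_b=total_seq, stride_h=0, stride_q=total_seq)
    for b in range(batch)
]
```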
---
## Changes by Area
### Attention Operator and Kernel Registration
* Added support for the Attention operator in opset 24 for CUDA,
including kernel registration for `float`, `MLFloat16`, and `BFloat16`
types, and refactored opset 23 kernel registration to use versioned
kernel macros for clarity and future extensibility.
### Env Var and Kernel Options Support
* Added `disable_flash_attention_` and
`disable_memory_efficient_attention_` members read from
`AttentionKernelOptions` via `GetAttentionKernelOptions()`. The ONNX
Attention kernel now respects `ORT_DISABLE_FLASH_ATTENTION` and
`ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION` environment variables, matching
the pattern used by all contrib CUDA attention ops.
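
A test that wants the unfused path can set both variables before the session is created. The Python sketch below is a minimal helper for that, assuming the value `"1"` disables a kernel and that the options are read when the kernel is constructed; the `attention_kernels_disabled` name is hypothetical.

```python
import os
from contextlib import contextmanager

_VARS = ("ORT_DISABLE_FLASH_ATTENTION", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION")


@contextmanager
def attention_kernels_disabled(names=_VARS):
    """Temporarily set the given env vars to "1", restoring them on exit."""
    saved = {name: os.environ.get(name) for name in names}
    try:
        for name in names:
            os.environ[name] = "1"   # assumed "disable" value
        yield
    finally:
        for name, value in saved.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
```

Any session that should fall back to the unfused kernel would be created inside the `with attention_kernels_disabled():` block.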
### Tensor Transposition Utilities
* Replaced 9 copy-paste `if constexpr` type-switch blocks with 2
templated helpers (`TransposeBNSHtoBSNH<T>`, `TransposeBSNHtoBNSH<T>`)
for efficiently transposing tensors between BxNxSxH and BxSxNxH formats
for `float`, `half`, and `BFloat16` types. Net -117 lines.
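
For readers unfamiliar with the two layouts, the Python sketch below spells out the index mapping between BxNxSxH and BxSxNxH on flat buffers. It is a reference model of the layout change only; the function names are hypothetical and it says nothing about how the CUDA helpers are implemented.

```python
def transpose_bnsh_to_bsnh(src, B, N, S, H):
    """Reorder a flat [B, N, S, H] buffer into flat [B, S, N, H]."""
    dst = [None] * len(src)
    for b in range(B):
        for n in range(N):
            for s in range(S):
                for h in range(H):
                    dst[((b * S + s) * N + n) * H + h] = src[((b * N + n) * S + s) * H + h]
    return dst


def transpose_bsnh_to_bnsh(src, B, N, S, H):
    """Reorder a flat [B, S, N, H] buffer into flat [B, N, S, H]."""
    dst = [None] * len(src)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    dst[((b * N + n) * S + s) * H + h] = src[((b * S + s) * N + n) * H + h]
    return dst


bnsh = list(range(6))                          # B=1, N=2, S=3, H=1
bsnh = transpose_bnsh_to_bsnh(bnsh, 1, 2, 3, 1)
```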
### Attention Mask and Sequence Length Handling
* Enhanced the mask-to-sequence length conversion kernel to support both
GQA and Flash conventions with a configurable offset, improved handling
for all-false masks, and refactored the launch function to accept the
offset parameter.
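
The conversion can be modelled in a few lines of Python. This sketch assumes a right-padded boolean mask with one row per batch entry, counts the valid positions, applies the offset, and clamps at zero as described in the bug-fix list. The `mask_to_seqlens_k` name is hypothetical, and the `-1` offset used below to stand for a last-valid-index convention is an illustrative assumption; the actual offsets used per convention are not stated here.

```python
def mask_to_seqlens_k(mask_rows, seqlen_offset=0):
    """Convert right-padded boolean mask rows to per-batch sequence lengths."""
    seqlens = []
    for row in mask_rows:
        seq_len = sum(1 for valid in row if valid)   # count semantics: 0 is valid
        # Clamp so an all-false row with a negative offset cannot go below zero.
        seqlens.append(max(0, seq_len + seqlen_offset))
    return seqlens


mask = [
    [True, True, False],     # two valid keys
    [False, False, False],   # all-false row
]
counts = mask_to_seqlens_k(mask)                        # count convention
shifted = mask_to_seqlens_k(mask, seqlen_offset=-1)     # assumed index convention
```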
### nonpad_kv_seqlen + attn_mask Composition
* MEA path: supports both via `has_custom_right_padding` (seqlens_k) +
additive attn_bias simultaneously.
* Unfused path: converts nonpad to attention_bias, then composes mask
additively via `LaunchAddBiasInPlace`. Supports 2D `[q, t]` and
4D `[B, 1, q, t]` mask shapes.
* Flash path: excluded (no bias parameter when seqlens_k is used).
Warning logged on fallback.
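
The additive composition used by the unfused path can be sketched in Python as follows. `nonpad_kv_seqlen` is first turned into a per-batch bias that is `0` for valid keys and `-inf` for padded ones, and a 2D `[q, t]` mask is then added in place, broadcast over the batch. All function names are hypothetical, the head dimension is omitted for brevity, and boolean masks are assumed to map `True` to `0` and `False` to `-inf`.

```python
NEG_INF = float("-inf")


def nonpad_to_bias(nonpad_kv_seqlen, q_seq, total_seq):
    """Bias of shape [B][q_seq][total_seq]: 0 for valid keys, -inf for padding."""
    return [
        [[0.0 if k < n else NEG_INF for k in range(total_seq)] for _ in range(q_seq)]
        for n in nonpad_kv_seqlen
    ]


def bool_mask_to_additive(mask_2d):
    """Assumed convention: True keeps the position, False masks it out."""
    return [[0.0 if keep else NEG_INF for keep in row] for row in mask_2d]


def add_mask_in_place(bias, additive_mask_2d):
    """Add a [q_seq][total_seq] mask to every batch entry of the bias."""
    for per_batch in bias:
        for q, row in enumerate(per_batch):
            for k in range(len(row)):
                row[k] += additive_mask_2d[q][k]


bias = nonpad_to_bias([3, 1], q_seq=1, total_seq=4)
add_mask_in_place(bias, bool_mask_to_additive([[True, False, True, True]]))
```

A position stays at `0` only when both the nonpad length and the mask allow it; any `-inf` contribution keeps it masked.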
### Alignment Checks for Attention Bias
* Added stricter alignment checks for attention bias strides in the
CUTLASS FMHA launch template, mirroring checks in the kernel to prevent
runtime enforcement errors.
### Helper Functions and API Improvements
* Added a new helper function for ungrouping in group query attention
(`LaunchUngroup`) and improved the output shape computation for
attention to optionally skip nonpad data validation for GPU paths.
* Refactored the CUDA Attention kernel class to split logic into
dedicated methods for FlashAttention, MemoryEfficientAttention, and
UnfusedAttention, each with ASCII dispatch diagrams documenting
supported parameter combinations.
---
## Testing
**39 C++ tests** in `attention_op_test.cc` covering opset 23/24
correctness, nonpad+mask composition (bool and float), BFloat16 Flash
guard, all-false mask regression, batch≠q_seq broadcast guard, and
nonpad edge cases.
New test files under
`onnxruntime/test/python/transformers/test_onnx_attention/`:
- **`test_mha.py`** (14 test classes, 16 test functions) — MHA path:
flash, MEA, unfused (fp32), BF16, 3D/4D inputs, bool/float masks,
past/present KV cache, causal mode, broadcast masks. Unfused tests
explicitly disable flash+MEA via env vars.
- **`test_gqa.py`** (10 test classes, 13 test functions) — GQA path:
flash, flash BF16, MEA, 3D/4D inputs, head ratios (32:8, 8:2),
bool/float masks, past/present KV cache, 4D BNSH present_kv correctness,
nonpad_kv_seqlen (Flash + MEA + CPU).
- **`test_tensorscatter_attention.py`** (5 test classes) — External KV
cache with `nonpad_kv_seqlen` (opset 24), TensorScatter-based cache
management, fp16 + fp32, with-mask variants (bool + float).
- **`common.py`** — Shared test infrastructure: ONNX model builders,
reference implementations, config classes.
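
For orientation, a reference implementation of the kind such tests compare against can be as small as the Python sketch below: single-head scaled dot-product attention with an optional additive bias. It is not the code in `common.py`; the `reference_attention` name and the list-based layout are hypothetical, and it assumes every query row has at least one unmasked key.

```python
import math


def reference_attention(q, k, v, bias=None):
    """q: [q_seq][d], k: [t][d], v: [t][dv], bias: [q_seq][t] or None."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_row in enumerate(q):
        scores = []
        for j, k_row in enumerate(k):
            s = scale * sum(a * b for a, b in zip(q_row, k_row))
            if bias is not None:
                s += bias[i][j]            # -inf removes the key from the softmax
            scores.append(s)
        m = max(scores)                    # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        out.append([sum(p * v_row[d] for p, v_row in zip(probs, v))
                    for d in range(len(v[0]))])
    return out


masked = reference_attention(
    q=[[1.0, 0.0]],
    k=[[1.0, 0.0], [0.0, 1.0]],
    v=[[1.0, 0.0], [0.0, 1.0]],
    bias=[[0.0, float("-inf")]],
)
```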
---------
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>