onnxruntime
1cfda524 - Support Attention(24)-CUDA and disjoint from contrib op (#27542)

Commit
89 days ago
Support Attention(24)-CUDA and disjoint from contrib op (#27542) Fix #27485 Reference https://github.com/microsoft/onnxruntime/pull/27486 ## Summary Implements CUDA kernel support for the **ONNX Attention op (opset 23 and 24)** using a thin-dispatcher architecture that directly calls Flash Attention, Memory Efficient Attention (cutlass FMHA), and Unfused (GEMM+softmax+GEMM) kernels — bypassing the contrib MHA/GQA dispatch layers entirely. This covers both Multi-Head Attention (MHA) and Grouped-Query Attention (GQA) variants, with support for internal KV cache (past/present), external KV cache (`nonpad_kv_seqlen`), boolean and float attention masks, combined `nonpad_kv_seqlen + attn_mask` composition, and 3D (BSNH) / 4D (BNSH) input formats. The kernel respects `ORT_DISABLE_FLASH_ATTENTION` and `ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION` environment variables, consistent with contrib MHA/GQA behavior. --- ## Architecture: Thin-Dispatcher Design The core kernel (`attention.cc`) implements a **3-path dispatch cascade** with clear priority: ``` ComputeInternal │ ├─ Early reject: softcap, softmax_precision (not supported by any CUDA kernel) │ ├─ Flash Attention (fp16/bf16, Ampere+, head_size==v_head_size, no output_qk) │ ├─ Path 1: nonpad_kv_seqlen (opset 24 external cache) → mha_fwd_kvcache │ ├─ Path 2: past_key + past_value (internal cache decode) → mha_fwd_kvcache │ │ Supports: bool mask → seqlens_k, no mask → fill past_seq_len │ └─ Path 3: no past, no mask (prompt) → mha_fwd │ Excluded: nonpad + mask combo (no bias param), float masks on prompt │ ├─ Memory Efficient (fp16/bf16, SM75+, no past_key, bias stride aligned) │ ├─ Path 1: nonpad_kv_seqlen (±mask) → custom_right_padding + additive bias │ ├─ Path 2: with mask (prompt) → standard MEA with additive bias │ └─ Path 3: no mask (prompt) → standard MEA │ └─ Unfused Fallback (all dtypes, all mask types, MHA only) ├─ nonpad_kv_seqlen (±mask) → attention_bias with additive composition └─ Bridges to contrib MHA's QkvToContext (GEMM+softmax+GEMM) ``` GQA head expansion is handled via `LaunchUngroup` in the MEA path. Flash handles GQA natively via `kv_num_heads` parameter. --- ## What DOESN'T Work ❌ | Combination | Reason | TODO? | | :--- | :--- | :--- | | **GQA + fp32** (any) | `MEA LaunchUngroup` is fp16/bf16 only; unfused rejects GQA | ✅ TODO added | | **GQA + float mask + decode** | Flash rejects float masks; MEA rejects decode; unfused rejects GQA | Known limitation | | **Softcap on CUDA** | All 3 kernels reject | ✅ TODO added | | **softmax_precision on CUDA** | All 3 kernels reject | ✅ TODO added | | **output_qk** beyond `kNone`/`kQK` | Only unfused supports, and only `kNone`/`kQK` | ✅ TODO added | \* 3 kernels mean Flash/MEA/Unfused ### Additional Known Limitations | Limitation | Details | | :--- | :--- | | **attn_mask shorter than total_seq** | Spec allows padding with -inf; validation enforces exact match. Tracked separately. | | **No past_present_share_buffer** | Decode always copies past→present (memset + strided copy). Contrib GQA can share buffers for zero-copy. | | **No XQA kernel** | Contrib GQA's specialized decode kernel requires shared buffers. ONNX Attention's internal-cache decode is ~15-30% slower for GQA workloads. This does not apply when using TensorScatter + `nonpad_kv_seqlen` (opset 24), which avoids the copy overhead entirely. | --- ## Bug Fixes - **2D mask shape interpretation (MEA path):** Fixed incorrect interpretation of 2D `attn_mask` as `[B, kv_seq]` — the ONNX spec defines it as `[q_seq, total_seq]`. The old code could read out-of-bounds when `batch_size > q_seq` (e.g., batched decode). Fix uses MEA's native broadcast flags (`strideB=0, strideH=0`) instead of an expansion kernel, which is both correct and more efficient. - **4D BNSH present_key/value (prompt path):** Added D2D `cudaMemcpyAsync` for the `!is_bsnh` case in both Flash and MEA prompt paths. Previously, `present_key`/`present_value` outputs were silently left uninitialized for 4D BNSH inputs. Decode path already worked via `mha_fwd_kvcache`; unfused already worked via `QkvToContext`. - **Cutlass FMHA bias alignment:** Added alignment checks for attention bias strides in the FMHA launch template, mirroring checks in the kernel to prevent runtime enforcement errors on misaligned strides. - **SEGFAULT fix** (`group_query_attention_impl.cu`): Added missing explicit template instantiations for `LaunchUngroup<__half>` and `LaunchUngroup<__nv_bfloat16>`. The new `attention.cc` calls `LaunchUngroup` cross-TU, but without explicit instantiations the symbol was undefined at runtime → `dlopen` crash → all CUDA test binaries failed. - **`val >= 0` assert fix** (`attention_mask_impl.cu`): Fixed overly strict `val > 0` asserts in Flash kernel and AttentionBias kernel — `val=0` is valid for count semantics. Kept `val > 0` in GQA kernel with comment explaining last-valid-index convention. - **Negative seqlens_k from all-false mask** (`attention_mask_impl.cu`): When mask is all-false (`seq_len=0`) and decode offset is negative, `seq_len + seqlen_offset` produced hugely negative `seqlens_k` passed to flash. Fixed with `max(0, seq_len + seqlen_offset)` clamp. - **BFloat16 Flash gate fix:** Fixed `disable_flash_attention_` constructor gate that used `!std::is_same<T, MLFloat16>::value`, accidentally excluding BFloat16 from Flash. Changed to `sizeof(T) != 2` to match contrib GQA/PackedMHA/PagedAttention pattern. --- ## Changes by Area ### Attention Operator and Kernel Registration * Added support for the Attention operator in opset 24 for CUDA, including kernel registration for `float`, `MLFloat16`, and `BFloat16` types, and refactored opset 23 kernel registration to use versioned kernel macros for clarity and future extensibility. ### Env Var and Kernel Options Support * Added `disable_flash_attention_` and `disable_memory_efficient_attention_` members read from `AttentionKernelOptions` via `GetAttentionKernelOptions()`. The ONNX Attention kernel now respects `ORT_DISABLE_FLASH_ATTENTION` and `ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION` environment variables, matching the pattern used by all contrib CUDA attention ops. ### Tensor Transposition Utilities * Replaced 9 copy-paste `if constexpr` type-switch blocks with 2 templated helpers (`TransposeBNSHtoBSNH<T>`, `TransposeBSNHtoBNSH<T>`) for efficiently transposing tensors between BxNxSxH and BxSxNxH formats for `float`, `half`, and `BFloat16` types. Net -117 lines. ### Attention Mask and Sequence Length Handling * Enhanced the mask-to-sequence length conversion kernel to support both GQA and Flash conventions with a configurable offset, improved handling for all-false masks, and refactored the launch function to accept the offset parameter. ### nonpad_kv_seqlen + attn_mask Composition * MEA path: supports both via `has_custom_right_padding` (seqlens_k) + additive attn_bias simultaneously. * Unfused path: converts nonpad to attention_bias, then composes mask additively via `LaunchAddBiasInPlace`. Supports 2D `[q, t]` and 4D `[B, 1, q, t]` mask shapes. * Flash path: excluded (no bias parameter when seqlens_k is used). Warning logged on fallback. ### Alignment Checks for Attention Bias * Added stricter alignment checks for attention bias strides in the CUTLASS FMHA launch template, mirroring checks in the kernel to prevent runtime enforcement errors. ### Helper Functions and API Improvements * Added a new helper function for ungrouping in group query attention (`LaunchUngroup`) and improved the output shape computation for attention to optionally skip nonpad data validation for GPU paths. * Refactored the CUDA Attention kernel class to split logic into dedicated methods for FlashAttention, MemoryEfficientAttention, and UnfusedAttention, each with ASCII dispatch diagrams documenting supported parameter combinations. --- ## Testing **39 C++ tests** in `attention_op_test.cc` covering opset 23/24 correctness, nonpad+mask composition (bool and float), BFloat16 Flash guard, all-false mask regression, batch≠q_seq broadcast guard, and nonpad edge cases. New test files under `onnxruntime/test/python/transformers/test_onnx_attention/`: - **`test_mha.py`** (14 test classes, 16 test functions) — MHA path: flash, MEA, unfused (fp32), BF16, 3D/4D inputs, bool/float masks, past/present KV cache, causal mode, broadcast masks. Unfused tests explicitly disable flash+MEA via env vars. - **`test_gqa.py`** (10 test classes, 13 test functions) — GQA path: flash, flash BF16, MEA, 3D/4D inputs, head ratios (32:8, 8:2), bool/float masks, past/present KV cache, 4D BNSH present_kv correctness, nonpad_kv_seqlen (Flash + MEA + CPU). - **`test_tensorscatter_attention.py`** (5 test classes) — External KV cache with `nonpad_kv_seqlen` (opset 24), TensorScatter-based cache management, fp16 + fp32, with-mask variants (bool + float). - **`common.py`** — Shared test infrastructure: ONNX model builders, reference implementations, config classes. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Author
Parents
Loading