onnxruntime
39d8520b - [CUDA] GQA CUDA Kernel Fusion and Performance Optimization (#26920)

Commit
109 days ago
[CUDA] GQA CUDA Kernel Fusion and Performance Optimization (#26920) ## Summary This PR significantly improves GroupQueryAttention (GQA) performance on CUDA by fusing multiple kernel launches, improving memory access patterns, and cleaning up sequence length semantics. ## Key Changes ### 1. Fused Kernels for Reduced Launch Overhead | New Kernel | Operations Fused | Kernels Saved | |------------|------------------|---------------| | `UnpackQKVWithRoPEAndAppendKV` | Unpack packed QKV + RoPE Q/K + KV cache append | 4-5 | | `ConcatNewToPastKVFused` | K append + V append (separate buffer mode) | 1 | | `ConcatKVInPlaceFused` | K append + V append (shared buffer mode) | 1 | ### 2. New `RotaryDispatcher` Template (`rotary_common.cuh`) Reusable RoPE implementation for fused kernels supporting: - `float`, `half`, `BFloat16` element types - `float2`, `float4` vector types - Interleaved and half-split rotation modes ### 3. Sequence Length Semantics Cleanup **Before:** Confusing `seqlens_k` / `seqlens_k_buff` with overloaded meanings. **After:** Clear separation: - `past_seq_lens` - offset where new tokens are appended - `total_seq_lens` - total valid tokens after append - `padded_seq_lens` - padded length for first prompt masking ### 4. FlashAttention Fast Decode Path New optimized path for token generation (`sequence_length == 1`, shared buffer): - Bypasses `GetSequenceLengths` kernel - Passes `past_seq_lens` directly to Flash Attention - Controlled by `ORT_DISABLE_FLASH_DECODE` env var ### 5. Integer Overflow Prevention All KV cache index calculations use `int64_t` to handle large `batch * heads * seq * head_size` products. ### 6. BFloat16 Vectorization Added `float4` (8 elements) vectorized path for BFloat16 in `ConcatTensorToTensor`. ## Environment Variables | Variable | Default | Description | |----------|---------|-------------| | `ORT_DISABLE_FLASH_DECODE` | `false` | Disable fast decode optimization | | `ORT_DISABLE_FUSED_KV` | `false` | Use unfused K/V append kernels | ## Test Changes ### Improved Test Coverage Strategy Restructured `gqa_cuda_prompt_test_cases()` and `gqa_cuda_past_test_cases()` to explicitly iterate over kernel code path parameters: ```python # NEW: Primary iteration over kernel code paths for h in h_sizes_to_test: for packed in packed_opts: for rotary, rotary_interleaved in rotary_opts: for share_buffer in share_buffer_opts: # Secondary params (batch, seq, heads) rotate via modulo ``` | Mode | Before | After | |------|--------|-------| | Pipeline | 16 tests, 4/12 combos | 42 tests, 8/12 combos | | Comprehensive | 81 tests, 4/12 combos | 178 tests, 12/12 combos | ### New Test Parameters - Added `seqs = [(1, 1)]` for edge case testing - Added `heads = [(3, 1)]` for non-standard GQA ratios - Added `h_sizes = [40]` for non-power-of-2 head sizes (tests rotary skip logic) ### New Test Configurations - `share_buffer` config option (tests both buffer modes) - `has_position_ids` testing on CUDA - Padding prompt parity test - Fused vs unfused kernel parity tests (`TestFusedKernelParity`) - Decoding from empty cache test case `(1, 1)` ## Files Changed **Core:** - `group_query_attention_impl.cu` - Main implementation refactoring - `attention_kv_cache.cu` - Fused append kernels - `flash_api.cc` - Packed QKV stride handling **New:** - `rotary_common.cuh` - Reusable RoPE dispatcher **Tests:** - `test_gqa.py` - Extended test coverage ## Performance For decoding or subsequent prompt, we still use original flash attention kernel, so the performance is almost same as baseline. Here we only show the results of first prompt. Below are results of benchmark_gqa.py on H200 GPU. Note that the latency is measured from onnx model of a GQA node, so the latency includes extra cost. The kernel speed up can be larger (See profiling results below). ### prompt-sm90-Llama3-8B-b1-h32_8x128-float16 **Configuration**: `batch=1, prompt (past_seq=0), num_heads=32, kv_heads=8, head_size=128, dtype=float16, gpu=H200` Dense mean Q, K and V are separated inputs. Packed means Q, K and V are packed into one input. | Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** | | --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- | | 1024 | 0.470 | 0.277 | **1.70x** | 0.468 | 0.320 | **1.46x** | | 2048 | 1.001 | 0.517 | **1.94x** | 0.990 | 0.590 | **1.68x** | | 4096 | 2.691 | 1.174 | **2.29x** | 1.504 | 1.242 | **1.21x** | | 8192 | 7.780 | 2.292 | **3.39x** | 7.933 | 4.004 | **1.98x** | ### prompt-sm90-Llama3-8B-b1-h32_8x128-bfloat16 **Configuration**: `batch=1, prompt (past_seq=0), num_heads=32, kv_heads=8, head_size=128, dtype=bfloat16, gpu=H200` | Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** | | --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- | | 1024 | 0.477 | 0.274 | **1.74x** | 0.486 | 0.332 | **1.46x** | | 2048 | 1.078 | 0.500 | **2.16x** | 1.087 | 0.601 | **1.81x** | | 4096 | 2.633 | 1.144 | **2.30x** | 3.017 | 1.282 | **2.35x** | | 8192 | 7.933 | 2.712 | **2.93x** | 7.933 | 4.003 | **1.98x** | # Profiling Comparison (Prompt Phase) **Summary**: Switching from `flash_fwd_splitkv_kernel` to standard `flash_fwd_kernel` for the prompt phase (SeqLen=2048) results in a **~3x reduction in attention kernel latency** and a **~2x improvement in total operator latency**. ## 1. Packed QKV **Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32, kv_heads=8, head_size=128` | Metric | Baseline | Treatment | Delta | | :--- | :--- | :--- | :--- | | **Total Latency** | **639.3 us** | **287.0 us** | **2.23x Speedup** | | **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.10 us | `flash_fwd_kernel`<br>187.70 us | **3.08x Speedup** | | **Helper Kernels** | `ConcatNewToPastKV`: 4.71 us | `UnpackQKVWithRoPEAndAppendKV`: 32.44 us<br>`GetSequenceLengths`: 1.63 us | *Fused ops added* | > **Note**: The Treatment implementation introduces a fused `UnpackQKVWithRoPEAndAppendKV` kernel which performs necessary pre-processing. Despite this added cost (~29 us), the massive gain from using the efficient `flash_fwd_kernel` instead of `flash_fwd_splitkv_kernel` yields a significant net speedup. ## 2. Dense (Separated QKV) **Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32, kv_heads=8, head_size=128` | Metric | Baseline | Treatment | Delta | | :--- | :--- | :--- | :--- | | **Total Latency** | **0.6468 ms** | **0.3226 ms** | **2.00x Speedup** | | **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.25 us | `flash_fwd_kernel`<br> 184.29 us | **3.08x Speedup** | | **Helper Kernels** | `ConcatNewToPastKV`: 4.68 us | `RotaryEmbeddingBSNH`: 48.94 us<br>`ConcatNewToPastKVFused`: 13.04 us<br>`GetSequenceLengths`: 1.52 us | *See below* | > **Note**: Similar to the Packed case, the switch to the standard Flash Attention forward kernel drives the performance improvement. The pre-processing is handled by `RotaryEmbeddingBSNH` and `ConcatNewToPastKVFused` in the treatment.
Author
Parents
Loading