onnxruntime
997c4798 - GQA unfused attention with FP32 QK accumulation (fixes #28195) (#28198)

Commit
9 days ago
GQA unfused attention with FP32 QK accumulation (fixes #28195) (#28198) ## Description Add a GQA-capable unfused CUDA attention kernel that writes Q·K^T to an FP32 scratch buffer, fixing fp16/bf16 overflow producing NaN when `head_size > 256` at `scale=1.0` (issue #28195, e.g. Gemma 4 global attention layers with `head_dim=512`). ### Motivation Gemma 4 uses `head_dim=512` for its global attention layers (`num_attention_heads=8, num_key_value_heads=4`). Flash Attention and Memory-Efficient Attention cap at `head_size=256`, so these fall through to the unfused path. The existing unfused MHA runner produces NaN because even though cuBLAS accumulates in FP32, the Q·K^T output tensor is fp16 and overflows. Additionally, the MHA unfused runner cannot handle GQA (`q_num_heads != kv_num_heads`). ### Key Changes **New kernel** (`contrib_ops/cuda/bert/gqa_unfused_attention.cu/.h`): - 3-stage pipeline: QK GEMM → softmax → AV GEMM - QK GEMM uses `CUBLAS_COMPUTE_32F` with `CUDA_R_32F` output type — raw Q·K^T scores written to FP32 scratch, eliminating fp16 overflow - Reshape-Q trick for native GQA support (no K/V head replication needed) - Softmax supports causal mask, sliding window (`local_window_size`), softcap, additive attention bias, and per-batch `seqlens_k` - Per-batch `past` calculation for correct sliding-window masking with variable-length sequences **GQA contrib op integration** (`group_query_attention.cc`, `group_query_attention_impl.cu`): - Activates when Flash/MEA/XQA are all ineligible and KV cache is not quantized - Uses `PrepareQKV` for RoPE and K/V cache management, then routes to the new kernel **ONNX Attention op integration** (`attention.cc`, `attention.h`): - New `RunGqaUnfusedAttention` path for GQA and fp16/bf16 with `head_size > 128` - Handles BSNH↔BNSH transposes, past_key concatenation, attn_mask→bias conversion, `nonpad_kv_seqlen` - Optimized: transposes BSNH K/V directly into `present_key`/`present_value` when available **`UnpackRoPEAppend` kernel** (`group_query_attention_qkv.cuh`): - Raised `MAX_HEAD_SIZE` from 256 to 512 to support Gemma 4 head dimensions **Safety improvements**: - `SafeInt<size_t>` for workspace size arithmetic (overflow protection) - `static_assert` guarding GQA transpose paths against accidental float instantiation ### Testing - 8 new Gemma 4 regression tests in `test_gqa.py`: prompt/decode × fp16/bf16, softcap, sliding window, long past sequences - 2 new Gemma 4 benchmark configs in `benchmark_gqa.py` (global + local attention) - All `TestGQARegressions` tests pass locally (12/12) ### Fixes Fixes #28195
Author
Parents
Loading