onnxruntime
19c9efc4 - [CUDA] Support FP8 (E4M3) KV Cache for Group Query Attention (#27321)

Commit
91 days ago
[CUDA] Support FP8 (E4M3) KV Cache for Group Query Attention (#27321) # Support FP8 (E4M3) KV Cache for Group Query Attention ## Description This PR adds FP8 E4M3 quantized KV cache support for the Group Query Attention (GQA) operator on CUDA, complementing the existing INT8 and INT4 quantization paths. FP8 KV caches reduce memory bandwidth requirements during inference while maintaining higher numerical precision than INT8 for the same storage footprint. ### Motivation FP8 (E4M3) format preserves floating-point semantics with a wider dynamic range than INT8 (±448 vs ±128), making it well-suited for KV cache compression in LLM inference. This is especially beneficial on Ada Lovelace (SM89+) GPUs which have native FP8 hardware support. ## Changes ### Build System - **cmake/CMakeLists.txt**: Added `onnxruntime_USE_FP8_KV_CACHE` build option (ON by default) with `USE_FP8_KV_CACHE` compiler flag. Also added build info strings for `fp8-kv-cache`, `dump-tensor`, and `dump-node` flags. ### Operator Schema - **bert_defs.cc**: Fixed shape inference for `present_key`/`present_value` when `total_sequence_length` input is provided and past_key has 0 length. The previous code could propagate a fixed dimension that later caused "Error merging shape info" warnings. ### Kernel Registration - **group_query_attention.cc**: Registered `GroupQueryAttention<MLFloat16, Float8E4M3FN>` and `<BFloat16, Float8E4M3FN>` kernel variants. Added FP8 XQA support gating (requires SM89+) and correct `XqaQuantType` mapping. - **cuda_contrib_kernels.cc**: Added FP8 kernel class declarations and `BuildKernelCreateInfo` entries. ### Core GQA Implementation - **group_query_attention_impl.cu**: Added template instantiations for `<half, __nv_fp8_e4m3>` and `<__nv_bfloat16, __nv_fp8_e4m3>`. Updated `FlashAttentionAndQuantizeKV` to dispatch to FP8 quantization kernels via `constexpr` type check. Wrapped INT4 instantiations in `#ifdef USE_INT4_KV_CACHE`. ### Quantization / Dequantization Kernels - **group_query_attention_qdq.cuh**: Added FP8 E4M3 paths in both `DequantizeKernel` and `QuantizeKernel` using `constexpr` type dispatch on `T_QUANT`. FP8 values are clamped to ±448 before conversion. ### Fused Unpack+RoPE+Append Kernel - **group_query_attention_qkv.cuh**: Refactored `LaunchUnpackRoPEAppend` to be templated on both `T` (query type) and `U` (cache type), replacing the runtime `bit_width` parameter with compile-time type-based dispatching. Added FP8 quantization path in the `UnpackRoPEAppend` kernel using `__nv_fp8_e4m3` type. Fixed cache pointer arithmetic to use byte-level addressing. ### XQA Kernel Integration - **mha.h**: Changed `InputElem` from `half` to `__nv_fp8_e4m3` when `CACHE_ELEM_ENUM == 2` (FP8). - **xqa_loader_fp16_impl.cuh / xqa_loader_bf16_impl.cuh**: Added extern declarations and dispatch logic for FP8 kernels (`LaunchXQAFp8Kernel` / `LaunchXQAFp8KernelBF16`). - **xqa_loader_fp16_fp8_impl.cuh / xqa_loader_bf16_fp8_impl.cuh** [NEW]: FP8 XQA kernel instantiation files with group sizes 4, 8, 16, 32. - **xqa_loader_{fp16,bf16}_fp8_{64,128,256}.cu** [NEW]: Per-head-size compilation units for FP8 XQA kernels. ### Python Tooling - **io_binding_helper.py**: Extended `TypeHelper` with comprehensive data type coverage: added FP8 (e4m3fn, e4m3fnuz, e5m2, e5m2fnuz), int4/uint4, double, int16/uint16, uint32/uint64, complex64/complex128, and string mappings across all conversion methods. ### Tests & Benchmarks - **test_gqa.py**: Added `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`, and `test_gqa_fp8_fallback_unsupported_head_size` test cases. Extended quantized test matrix to include FP8. Added FP8-specific tolerance values. - **gqa_test_helper.py**: Added FP8 cache type handling in `parity_check_gqa_past` and `parity_check_gqa_prompt` for proper tensor creation and dequantization comparison. - **benchmark_gqa.py**: Added FP8 benchmark support with `--fp8` flag. ## Testing - Unit tests: `test_gqa_fp8_kv_cache`, `test_gqa_fp8_prompt`, `test_gqa_fp8_fallback_unsupported_head_size` - Quantized test matrix expanded with FP8 variants (PER_TENSOR, PER_CHANNEL, shared/separate scales) - Benchmark: `benchmark_gqa.py --fp8` ## Requirements - CUDA GPU with SM89+ (Ada Lovelace / RTX 4000 series or newer) for FP8 support - Build with `onnxruntime_USE_FP8_KV_CACHE=ON` (default)
Author
Parents
Loading