onnxruntime
6fbda6d8 - Add quantized KV cache support in CPU GroupQueryAttention (#28576)

Commit
1 day ago
## Description

This PR adds INT4/INT8 symmetric quantized KV cache support to the CPU GroupQueryAttention contrib operator, reducing memory bandwidth requirements during inference. The quantized path quantizes K/V values on write into the present cache and performs dequantized GEMM (QK and SV) during attention computation, maintaining FP32 accumulation for accuracy.

Note that this is a baseline implementation. Further optimizations (such as AVX2 and NEON kernels) will come in a follow-up PR.

## Summary of Changes

### MLAS Quantization Kernels

| File | Change |
|------|--------|
| `onnxruntime/core/mlas/inc/mlas_qkv_quant.h` | New public API header for INT4/INT8 KV-cache quantize, dequantize, and GEMM routines |
| `onnxruntime/core/mlas/lib/qkv_quant.cpp` | Portable reference implementation of `MlasKVQuantize`, `MlasQKGemm`, `MlasSVGemm`, `MlasKVDequantize` |
| `cmake/onnxruntime_mlas.cmake` | Register new source/header files in the MLAS build |

### CPU GQA Operator Changes

| File | Change |
|------|--------|
| `onnxruntime/contrib_ops/cpu/bert/attention_common.h` | Add `StringToKVQuantizationType` helper for attribute parsing |
| `onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h` | Add `ConcatQuantStateChunkGQA`, `ToMlasKVQuantType`, quantized attention base members, and `ApplyAttentionQuantized` implementation |
| `onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc` | Extend kernel registration for `T_CACHE`/`T_KV_SCALE` type constraints; add quantization validation and dispatch to quantized path |
| `onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h` | Minor cleanup of helper logic |

### CUDA EP Guard

| File | Change |
|------|--------|
| `onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc` | Add build-time guard for the case where INT4 KV cache is not enabled |

### Tests

| File | Change |
|------|--------|
| `onnxruntime/test/contrib_ops/group_query_attention_op_test.cc` | C++ unit tests for quantized KV cache (INT8/INT4, per-tensor/per-channel, with/without past) |
| `onnxruntime/test/mlas/unittest/test_qkv_quant.cpp` | MLAS-level unit tests for quantize/dequantize/GEMM kernels |
| `onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py` | Python integration tests validating end-to-end quantized GQA accuracy |

## Testing

- Run C++ tests: `./onnxruntime_test_all --gtest_filter='*GroupQueryAttention*Quant*'`
- Run MLAS tests: `./onnxruntime_test_all --gtest_filter='*QKVQuant*'`
- Run Python tests: `pytest onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py -v`
- All existing GQA tests continue to pass (no behavioral change for non-quantized paths)

## Motivation and Context

Quantized KV caches significantly reduce memory bandwidth requirements for long-context LLM inference on CPU. The CUDA EP already supports INT4/INT8 quantized KV caches; this PR brings parity to the CPU EP. The MLAS kernels use the same packing conventions as the CUDA implementation for model portability.

## Checklist

- [x] Tests added/updated
- [x] No breaking changes (new optional inputs/attributes, existing behavior unchanged)
- [x] CI passes
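
## Illustrative Sketch

The Description's quantize-on-write and dequantize-during-GEMM flow can be illustrated with a minimal, standard-library Python sketch. Everything here is an assumption made for illustration: the helper names (`quantize_symmetric`, `dequantize`, `qk_dot`) are hypothetical and are not the MLAS API, the scale rule (`max|x| / qmax`) and the symmetric code range (`[-qmax, qmax]`) are common conventions rather than details confirmed by this PR, only per-tensor scaling is shown, and INT4 packing is omitted because the packing layout is not described here. Python floats are double precision, so the accumulator merely stands in for the FP32 accumulation the PR mentions.

```python
# Illustrative sketch only: hypothetical helpers, NOT the MLAS API.

def quantize_symmetric(values, bits=8):
    """Per-tensor symmetric quantization; returns (integer codes, scale)."""
    qmax = (1 << (bits - 1)) - 1  # 127 for INT8, 7 for INT4
    max_abs = max((abs(v) for v in values), default=0.0)
    # Assumed scale rule: map the largest magnitude onto qmax.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Round to nearest code, then clamp to the symmetric range.
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate floating-point values from integer codes."""
    return [c * scale for c in codes]


def qk_dot(query, key_codes, key_scale):
    """Dot product of a float query with one quantized key row.

    Each key element is dequantized on the fly and the sum is kept in
    floating point, mirroring the idea of a dequantized GEMM.
    """
    acc = 0.0
    for q, c in zip(query, key_codes):
        acc += q * (c * key_scale)
    return acc
```

For example, `quantize_symmetric(key_row, bits=4)` would produce codes in `[-7, 7]`, and `qk_dot(query, codes, scale)` approximates the unquantized dot product with a per-element error of at most half the scale.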