d9bb7602 - Enable KleidiAI for asymmetric 4-bit MatMulNBits on ARM64 (#27751)

### Description

Extends the KleidiAI-accelerated `MatMulNBits` path on ARM64 to support asymmetric 4-bit quantization (models with per-block zero points). Previously, KleidiAI was used only for symmetric quantization (the `!HasZp` guard), and asymmetric models fell back to a significantly slower non-KleidiAI kernel.

Since KleidiAI provides only symmetric int4 micro-kernels (`kai_matmul_clamp_f32_qai8dxp_qsi4c32p`), this PR runs KleidiAI as if the weights were symmetric (hardcoded `rhs_zero_point = 8`) and applies a float-domain zero-point correction after the GEMM:

$$C_{\text{actual}} = C_{\text{symmetric}} + A_{\text{blksum}} \times B_{\text{ZpCorr}}^T$$

where:

- `BZpCorr[n, blk] = scale_b[n, blk] × (8 - zp_b[n, blk])`, precomputed at weight-packing time
- `AFloatBlkSum[m, blk] = Σ A_float[m, blk_start..blk_end]`, computed per inference alongside A quantization

Key changes:

- **`UseKleidiAIBase()`**: new function that checks KleidiAI eligibility without the `!HasZp` guard. `UseKleidiAI()` now delegates to `!HasZp && UseKleidiAIBase()`, preserving the symmetric behavior.
- **B packing (`SQ4BitGemmPackQuantBDataAndBlkSum`)**: computes and stores `BZpCorr` after the KleidiAI-packed B data when zero points are present.
- **Workspace expansion**: allocates space for `AFloatBlkSum` (M × BlockCountK floats) in the per-GEMM workspace for asymmetric models.
- **`ComputeAFloatBlkSum`**: NEON-vectorized (4× unrolled `vaddq_f32`) function that computes per-block float sums of A.
- **`ApplyBZpCorrection`**: NEON-vectorized correction kernel tiled 4-N-wide (`vfmaq_f32`) for L1-friendly `BZpCorr` reuse.
- **PrePack**: computes `BZpCorr` during the scales PrePack (not the zero_points PrePack), since ORT may erase constant inputs after marking them packed.

No changes to the symmetric path, to x64, or to 8-bit quantization.
### Motivation and Context

Asymmetric 4-bit quantized models (e.g., GPTQ/RTN with zero points) on ARM64 were **23–72% slower** than their symmetric counterparts because KleidiAI's `sdot`/`i8mm` micro-kernels support only a symmetric RHS, forcing a fallback to a slower non-KleidiAI kernel path. This change closes most of that gap:

| Model | Seq Len | Asym/Sym (before) | Asym/Sym (after) | Asym speedup | Asym latency (after) |
|-------|---------|-------------------|------------------|--------------|----------------------|
| Qwen 1.5B | 256 | 1.35× | 1.17× | **1.16×** | 1107.8 ms |
| Qwen 1.5B | 512 | 1.23× | 1.06× | **1.14×** | 2259.7 ms |
| Qwen 3B | 256 | 1.43× | 1.12× | **1.28×** | 2029.7 ms |
| Qwen 3B | 512 | 1.39× | 1.22× | **1.24×** | 4188.0 ms |
| Qwen 7B | 256 | 1.61× | 1.11× | **1.52×** | 3661.6 ms |
| Qwen 7B | 512 | 1.72× | 1.11× | **1.58×** | 7263.8 ms |

The remaining 6–22% asym/sym gap comes from the extra pass over A that computes the float block sums; this pass cannot be fused into KleidiAI's sealed A-packing function and would require an upstream KleidiAI API change.