onnxruntime
cb1849d3 - Add zero point support to dp4a 2-bit dequantization in the WebGPU MatMulNbits (#27325)

Commit
64 days ago
Add zero point support to dp4a 2-bit dequantization in the WebGPU MatMulNBits kernel. Previously, the dp4a path for 2-bit quantization used a hardcoded 256-entry LUT assuming zero_point=2, and was blocked from running when custom zero points were provided.

1. dp4a_matmul_common.wgsl.template: Core LUT & dequantization function
- Added a 1024-entry LUT (4 sections × 256 entries) when has_zero_points is true. Each section corresponds to a zero point value (0 to 3), pre-computing pack4xI8(value - zero_point) for every possible byte input.
- Added a new DequantizedFrom2BitsTo8Bits(in: u32, zero: i32) overload that indexes the LUT as zero * 256 + byte_value.
- Original 256-entry LUT and parameterless function preserved for the !has_zero_points path.

2. dp4a_matmul.wgsl.template: Large-M tiled kernel (workgroup=256)
- loadSHMB for n_bits==2: reads zero point via mm_read_zero() and passes it to DequantizedFrom2BitsTo8Bits(b_value, zero) when has_zero_points.
- LoadDequantizationTable: expanded to 4 calls (local_idx + 0/256/512/768) to load all 1024 entries when has_zero_points.

3. dp4a_matmul_small_m.wgsl.template: Small-M kernel (workgroup=128)
- LoadDequantizationTable: expanded to 8 calls to load 1024 entries when has_zero_points.
- DequantizedFrom2BitsTo8Bits calls pass zero when has_zero_points.
- Bug fix: corrected off-by-one local_idx+127 → local_idx+128 in the non-zero-point path.

4. matmul_nbits.cc: Kernel dispatch logic
- Removed the guard !(has_zero_points && nbits == 2) that previously blocked the dp4a path for 2-bit with custom zero points.
- Updated comment to document the new 1024-entry LUT support.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
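The LUT layout from item 1 and the table-loading fix from item 3 can be illustrated with a small Python model. This is a sketch under stated assumptions, not the shader code: the helper names (pack4x_i8, build_lut, dequantize_byte, covered_indices) are hypothetical, the four 2-bit values in a byte are assumed to be ordered low bits first, and each LoadDequantizationTable call is assumed to copy one entry per thread at the index it is given.

```python
def pack4x_i8(values):
    """Pack four signed 8-bit ints into a 32-bit word, first value in the
    low byte (intended to mirror WGSL pack4xI8)."""
    word = 0
    for i, v in enumerate(values):
        word |= (v & 0xFF) << (8 * i)
    return word


def unpack4x_i8(word):
    """Inverse of pack4x_i8, used here only to inspect LUT entries."""
    out = []
    for i in range(4):
        b = (word >> (8 * i)) & 0xFF
        out.append(b - 256 if b >= 128 else b)
    return out


def build_lut(num_zero_points=4):
    """Build the 4 x 256 table: one 256-entry section per zero point 0..3.
    Assumption: the 2-bit values are taken from the byte low bits first."""
    lut = []
    for zero in range(num_zero_points):
        for byte in range(256):
            quads = [(byte >> (2 * i)) & 0x3 for i in range(4)]
            lut.append(pack4x_i8([q - zero for q in quads]))
    return lut


def dequantize_byte(lut, byte_value, zero):
    """Index the table as zero * 256 + byte_value, as the commit describes."""
    return lut[zero * 256 + byte_value]


def covered_indices(workgroup_size, offsets):
    """Entries written if every thread local_idx in [0, workgroup_size)
    loads entry local_idx + offset for each offset (assumed load pattern)."""
    return {local_idx + off
            for off in offsets
            for local_idx in range(workgroup_size)}


lut = build_lut()
# Byte 0b11100100 holds the 2-bit values 0, 1, 2, 3 (low bits first).
print(unpack4x_i8(dequantize_byte(lut, 0b11100100, 2)))  # [-2, -1, 0, 1]

# Small-M kernel, 256-entry table, 128 threads: +127 leaves entry 255 unloaded.
print(sorted(set(range(256)) - covered_indices(128, [0, 127])))  # [255]
print(covered_indices(128, [0, 128]) == set(range(256)))         # True
```

Under these assumptions, the section for zero point 2 reproduces the values the original hardcoded table would hold, and the same coverage check shows why 4 calls suffice for a 256-thread workgroup while a 128-thread workgroup needs 8 to fill all 1024 entries.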