Add zero point support to dp4a 2-bit dequantization in the WebGPU MatMulNBits (#27325)
Add zero point support to dp4a 2-bit dequantization in the WebGPU
MatMulNBits kernel. Previously, the dp4a path for 2-bit quantization
used a hardcoded 256-entry LUT that assumed zero_point=2, and the path
was blocked from running when custom zero points were provided.
1. dp4a_matmul_common.wgsl.template — Core LUT & dequantization function
Added a 1024-entry LUT (4 sections × 256 entries) when has_zero_points
is true. Each section corresponds to one zero point value (0–3) and
pre-computes pack4xI8(value - zero_point) for every possible byte
input.
Added a new DequantizedFrom2BitsTo8Bits(in: u32, zero: i32) overload
that indexes the LUT as zero * 256 + byte_value.
The original 256-entry LUT and the parameterless overload are
preserved for the !has_zero_points path.
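The LUT layout above can be modeled outside WGSL. The sketch below is a minimal Python analogue, not the shader code itself: `pack4xi8` mirrors the WGSL `pack4xI8` builtin (component i lands in byte i of the u32), and the assumption that the four 2-bit values are unpacked lowest-bits-first is mine, not stated in this PR.

```python
def pack4xi8(vals):
    """Pack four signed 8-bit ints into one u32, byte i = component i
    (mirrors the WGSL pack4xI8 builtin)."""
    out = 0
    for i, v in enumerate(vals):
        out |= (v & 0xFF) << (8 * i)
    return out

def build_lut():
    """Build the 1024-entry table: 4 sections of 256 entries, one section
    per zero point value 0..3."""
    lut = []
    for zero in range(4):
        for byte in range(256):
            # Unpack four 2-bit weights from the byte (assumed lowest
            # bits first) and subtract the section's zero point.
            vals = [((byte >> (2 * i)) & 0x3) - zero for i in range(4)]
            lut.append(pack4xi8(vals))
    return lut

def dequant_2bit_to_8bit(byte, zero, lut):
    # Mirrors DequantizedFrom2BitsTo8Bits(in, zero): index = zero * 256 + byte
    return lut[zero * 256 + byte]
```

For example, the byte 0xE4 packs the weights [0, 1, 2, 3]; with zero point 2 they dequantize to [-2, -1, 0, 1].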
2. dp4a_matmul.wgsl.template — Large-M tiled kernel (workgroup=256)
loadSHMB for n_bits==2: reads zero point via mm_read_zero() and passes
it to DequantizedFrom2BitsTo8Bits(b_value, zero) when has_zero_points.
LoadDequantizationTable: expanded to 4 calls (local_idx + 0/256/512/768)
to load all 1024 entries when has_zero_points.
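The 4-call load pattern amounts to a strided cooperative fill: each of the 256 threads writes the entries at local_idx + 0/256/512/768, so the 1024 entries are covered exactly once. A serialized Python sketch of that access pattern (on the GPU the threads run concurrently):

```python
WORKGROUP_SIZE = 256
LUT_SIZE = 1024

def load_dequantization_table(lut_src):
    """Model of the 4-pass cooperative load: thread `local_idx` copies
    entries local_idx + 0, + 256, + 512, and + 768 into shared memory."""
    shared = [None] * LUT_SIZE
    for local_idx in range(WORKGROUP_SIZE):
        for offset in (0, 256, 512, 768):
            shared[local_idx + offset] = lut_src[local_idx + offset]
    return shared
```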
3. dp4a_matmul_small_m.wgsl.template — Small-M kernel (workgroup=128)
LoadDequantizationTable: expanded to 8 calls to load 1024 entries when
has_zero_points.
DequantizedFrom2BitsTo8Bits calls pass zero when has_zero_points.
Bug fix: corrected off-by-one local_idx+127 → local_idx+128 in the
non-zero-point path.
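The off-by-one fix is easy to verify by enumerating the write indices: with 128 threads filling a 256-entry table in two passes, a second-pass offset of 127 double-writes entry 127 and never writes entry 255, while an offset of 128 covers 0..255 exactly once.

```python
def covered_entries(workgroup_size, second_offset):
    """Indices written by a two-pass cooperative load: each thread writes
    local_idx, then local_idx + second_offset."""
    writes = []
    for local_idx in range(workgroup_size):
        writes.append(local_idx)                  # first pass
        writes.append(local_idx + second_offset)  # second pass
    return sorted(writes)
```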
4. matmul_nbits.cc — Kernel dispatch logic
Removed the guard !(has_zero_points && nbits == 2) that previously
blocked the dp4a path for 2-bit with custom zero points.
Updated comment to document the new 1024-entry LUT support.
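The dispatch change can be sketched as a predicate. This is a hypothetical Python rendering of the guard described above, not the actual C++ in matmul_nbits.cc; the other dispatch criteria the kernel checks are elided.

```python
def dp4a_allowed_before(nbits, has_zero_points):
    # Old guard: reject the dp4a path for 2-bit weights with custom zero points.
    return not (has_zero_points and nbits == 2)

def dp4a_allowed_after(nbits, has_zero_points):
    # Guard removed: 2-bit + zero points now uses the 1024-entry LUT path.
    # (Other, unrelated dispatch checks are omitted from this sketch.)
    return True
```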
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>