onnxruntime
80739023 - Validate B/scales/zero_points shape in MatMulNBits::PrePack

Commit
3 days ago
Validate B/scales/zero_points shape in MatMulNBits::PrePack MatMulNBits::PrePack ran at session initialization and called the MLAS pack routines using byte counts derived from the node attributes (N, K, bits, block_size) without ever comparing those attributes to the actual tensor Shape(). A crafted .onnx whose attributes overstate the real B (or scales / zero_points) extent triggered a heap-buffer-overflow READ inside MlasQNBitGemmPackQuantBData / MlasLutGemmPack during OrtApis::CreateSession (no Run() required). The canonical shape check already lives in matmul_nbits_helper::CheckInputs, but is invoked only from Compute() -- after PrePack has already done the OOB read, and by then the original B tensor is replaced with nullptr in the kernel context so the Compute-time check never re-validates it. Fix: at the top of PrePack, after the existing early-return guards and before any tensor.DataRaw() read, validate the incoming initializer's Shape() against the attribute-derived shape: - B -> (N, k_blocks, blob_size) - scales -> (N * k_blocks) or (N, k_blocks) - zero_points -> uint8: (N * zp_blob) or (N, zp_blob); else (N * k_blocks) or (N, k_blocks) A mismatch returns INVALID_ARGUMENT so the session fails to load rather than reading past the buffer.
Author
Parents
Loading