Validate B/scales/zero_points shape in MatMulNBits::PrePack

MatMulNBits::PrePack runs at session initialization and calls the MLAS
pack routines with byte counts derived from the node attributes
(N, K, bits, block_size), without ever comparing those attributes to
the actual tensor Shape(). A crafted .onnx whose attributes overstate
the real extent of B (or scales / zero_points) triggers a
heap-buffer-overflow READ inside MlasQNBitGemmPackQuantBData /
MlasLutGemmPack during OrtApis::CreateSession; no Run() is required.

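To make the size mismatch concrete, here is a minimal sketch of how a byte count for B can be derived purely from attributes. The helper names (CeilDivSize, ExpectedBBytes) are hypothetical and are not ONNX Runtime or MLAS functions, and the formulas k_blocks = ceil(K / block_size) and blob_size = ceil(block_size * bits / 8) are assumptions about how these terms are defined.

```cpp
#include <cstddef>

// Hypothetical helper, not part of ONNX Runtime: integer ceiling division.
constexpr std::size_t CeilDivSize(std::size_t a, std::size_t b) {
  return (a + b - 1) / b;
}

// Bytes a packer would expect to read from B if it trusted only the
// attributes. Assumes B is laid out as (N, k_blocks, blob_size) bytes.
constexpr std::size_t ExpectedBBytes(std::size_t N, std::size_t K,
                                     std::size_t bits, std::size_t block_size) {
  const std::size_t k_blocks = CeilDivSize(K, block_size);         // blocks per column
  const std::size_t blob_size = CeilDivSize(block_size * bits, 8); // bytes per block
  return N * k_blocks * blob_size;
}
```

Under these assumptions, attributes N=4096, K=4096, bits=4, block_size=32 imply 8388608 bytes, while an initializer whose real shape is (16, 128, 16) holds only 32768 bytes; a reader that trusts the attributes walks far past the end of the buffer.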
The canonical shape check already lives in
matmul_nbits_helper::CheckInputs, but it is invoked only from
Compute(), after PrePack has already performed the out-of-bounds read.
By then the original B tensor has been replaced with nullptr in the
kernel context, so the Compute-time check never re-validates it.

Fix: at the top of PrePack, after the existing early-return guards
and before any tensor.DataRaw() read, validate the incoming
initializer's Shape() against the attribute-derived shape:
- B -> (N, k_blocks, blob_size)
- scales -> (N * k_blocks) or (N, k_blocks)
- zero_points -> uint8: (N * zp_blob) or (N, zp_blob);
  otherwise: (N * k_blocks) or (N, k_blocks)
A mismatch returns INVALID_ARGUMENT so the session fails to load
rather than reading past the buffer.
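
The checks listed above can be sketched as standalone predicates. This is an illustration only, not the code in this change: Shape, Attrs, DivUp and the three *ShapeIsValid functions are hypothetical names, and the definitions k_blocks = ceil(K / block_size), blob_size = ceil(block_size * bits / 8) and zp_blob = ceil(k_blocks * bits / 8) are assumptions.

```cpp
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for a tensor shape

// Hypothetical bundle of the node attributes.
struct Attrs {
  int64_t N, K, bits, block_size;
};

inline int64_t DivUp(int64_t a, int64_t b) { return (a + b - 1) / b; }

// B must be exactly (N, k_blocks, blob_size).
inline bool BShapeIsValid(const Shape& s, const Attrs& a) {
  const int64_t k_blocks = DivUp(a.K, a.block_size);
  const int64_t blob_size = DivUp(a.block_size * a.bits, 8);
  return s == Shape{a.N, k_blocks, blob_size};
}

// scales may be flat (N * k_blocks) or 2-D (N, k_blocks).
inline bool ScalesShapeIsValid(const Shape& s, const Attrs& a) {
  const int64_t k_blocks = DivUp(a.K, a.block_size);
  return s == Shape{a.N * k_blocks} || s == Shape{a.N, k_blocks};
}

// Packed uint8 zero_points use zp_blob bytes per row; other element
// types carry one value per block, like scales.
inline bool ZeroPointsShapeIsValid(const Shape& s, const Attrs& a, bool is_uint8) {
  const int64_t k_blocks = DivUp(a.K, a.block_size);
  const int64_t cols = is_uint8 ? DivUp(k_blocks * a.bits, 8) : k_blocks;
  return s == Shape{a.N * cols} || s == Shape{a.N, cols};
}
```

In the real kernel, a false result from any of these would correspond to returning INVALID_ARGUMENT before any tensor.DataRaw() read.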