onnxruntime
9a4f4632 - [MLAS] Fix Lut GEMM Flakiness and Accuracy (#27216)

Commit
2 days ago
[MLAS] Fix Lut GEMM Flakiness and Accuracy (#27216)

This PR resolves flakiness and accuracy issues in the `MatMulNBitsLutGemm` operator.

## Root Cause Analysis

The `MatMulNBitsLutGemm` operator exhibited non-deterministic flakiness and numerical accuracy issues. This analysis covers the root causes addressed by the changes.

## Identified Root Causes

### 1. Data Race in [LutGemmPackQuantBData](https://github.com/microsoft/onnxruntime/blob/cee825d34d533ca325bfd8f8269c86133ae512e6/onnxruntime/core/mlas/lib/qlutgemm.cpp#L166-L295)

- **Issue**: The weight packing loop was parallelized across output features ($N$). Since T-MAC packs multiple features into a single byte, concurrent updates to the same byte caused bit-level corruption.
- **Fix**: Serialized the sub-byte accumulation phase of the weight packing process.

### 2. Thread-Safety in Global Configuration Map

- **Issue**: `tmac_kernel_configs` (a static `std::unordered_map`) was accessed concurrently. Map insertions or rehashing during initialization could invalidate references held by other threads.
- **Fix**: Added `std::mutex` protection and modified the parameter getter to return by value.

### 3. Tiling Dimension Mismatch and Buffer Safety

- **Issue**: The orchestrator used batch size ($M$) for kernel configuration, while weights are tiled by features ($N$). Additionally, the kernel lacked clamping for partial tiles, leading to potential overruns.
- **Fix**: Synchronized tiling logic by using $N$ for initialization, passing `TotalN` for parameter retrieval, and implementing explicit clamping and tail-case handling in the AVX2 kernel.

### Verification Results

- `MatMulNBitsLutGemm.Float32_2Bits_Asymmetric_Batch32_256x256` passed 100 consecutive iterations.
- Full MatMul2Bits suite passed all 10 tests with standard **0.15f** tolerance.
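To illustrate root cause 1, here is a minimal sketch of sub-byte packing. It is not the MLAS code: `PackTwoBitSerial` is a hypothetical helper, and the layout (four 2-bit values per byte, lowest bits first) is an assumption chosen for simplicity rather than the actual T-MAC layout. The point is only that neighbouring features perform a read-modify-write on the same byte, so the accumulation must not run concurrently across features without synchronization.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLAS): packs one 2-bit value per output
// feature, four features to a byte, lowest bits first (assumed layout).
std::vector<uint8_t> PackTwoBitSerial(const std::vector<uint8_t>& values)
{
    std::vector<uint8_t> packed((values.size() + 3) / 4, 0);

    // Features n, n+1, n+2, n+3 within one group of four all update
    // packed[n / 4]. The |= below is a non-atomic read-modify-write, so
    // splitting this loop across threads by feature index would let two
    // threads overwrite each other's bits. Running it serially avoids that.
    for (size_t n = 0; n < values.size(); ++n) {
        const uint8_t v = values[n] & 0x3;          // keep the low 2 bits
        const unsigned shift = 2 * (n % 4);         // slot within the byte
        packed[n / 4] |= static_cast<uint8_t>(v << shift);
    }
    return packed;
}
```

An alternative that keeps parallelism would be to partition the work by output byte instead of by feature, so that each byte has exactly one writer; the commit message states that serialization was the approach taken.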
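The fix for root cause 2 can be sketched along the following lines. The names `KernelParams`, `KernelConfigRegistry`, `Set`, and `GetOrDefault` are hypothetical and do not reflect the actual MLAS types or signatures. The sketch assumes a single mutex guarding all access to the map and a getter that copies the entry out while the lock is held, so callers never hold a reference into the shared container.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the kernel parameters; field names are illustrative.
struct KernelParams {
    size_t tile_n = 0;
    size_t tile_k = 0;
};

// Hypothetical registry: every read and write of the map takes the mutex.
class KernelConfigRegistry {
public:
    void Set(const std::string& key, const KernelParams& params)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        configs_[key] = params;
    }

    // Returns a copy made under the lock, so the caller's value is unaffected
    // by later insertions or updates. Missing keys yield a default value.
    KernelParams GetOrDefault(const std::string& key) const
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = configs_.find(key);
        return it == configs_.end() ? KernelParams{} : it->second;
    }

private:
    mutable std::mutex mutex_;
    std::unordered_map<std::string, KernelParams> configs_;
};
```

Returning by value costs a small copy per lookup, but it removes any lifetime coupling between callers and the map, which is usually the simpler trade-off for small parameter structs.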
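The clamping mentioned under root cause 3 follows a common pattern, sketched below. `TileRanges` is a hypothetical helper and the tile size is an arbitrary illustrative value; the actual AVX2 kernel is not shown in the commit message. The sketch assumes tiles are laid out along the feature dimension $N$ and that the last tile may be partial.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: splits [0, total_n) into (start, count) tiles of at
// most tile_n features each. The final tile is clamped so that
// start + count never exceeds total_n.
std::vector<std::pair<size_t, size_t>> TileRanges(size_t total_n, size_t tile_n)
{
    std::vector<std::pair<size_t, size_t>> ranges;
    if (tile_n == 0) {
        return ranges;  // nothing sensible to do with a zero tile size
    }
    for (size_t start = 0; start < total_n; start += tile_n) {
        // Without this clamp, a kernel that always processes tile_n
        // features would run past the end of the buffer on the tail tile.
        const size_t count = std::min(tile_n, total_n - start);
        ranges.emplace_back(start, count);
    }
    return ranges;
}
```

For example, with `total_n = 100` and `tile_n = 32` the helper yields three full tiles and a tail tile of 4 features starting at offset 96.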