onnxruntime
db4b0f4e - [CPU] Improve QMoE kernel (#25822)

Commit
122 days ago
[CPU] Improve QMoE kernel (#25822) This pull request introduces several improvements and refactorings to the quantized Mixture-of-Experts (QMoE) operator in ONNX Runtime, focusing on enhanced support for FP32 mode, improved SwiGLU activation handling, and better test coverage. The most important changes are grouped below by theme. ### Operator Registration and Type Support - Added explicit registration and support for `QMoE` operator with both `MLFloat16` and `float` data types, enabling FP32 (non-quantized) mode in addition to quantized modes. This includes updates to kernel registration and schema/type constraints. [[1]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L109-R110) [[2]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L275-R277) [[3]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1467-R1467) [[4]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1548-R1548) ### SwiGLU Activation Improvements - Refactored `ApplySwiGLUActivation` to accept configurable `activation_alpha` and `activation_beta` parameters, matching CUDA behavior and allowing flexibility in activation function tuning. Also, dropped support for non-interleaved memory layouts (now not implemented). [[1]](diffhunk://#diff-4e4afb8dcdade0abe18bd8bea68b148b4090cd86d60a1b1422c049960231737dR49-R60) [[2]](diffhunk://#diff-edb344a38502bba9a0083ab98e274ec1b5b2606639a61df7be474a600a7b99d2L29-R61) [[3]](diffhunk://#diff-f85806c745243652a0336da094126687a6c0d14b19fe760abe73df1d940dc4cbL12-R13) - Now reads `activation_alpha` and `activation_beta` attributes from operator parameters, defaulting to values appropriate for SwiGLU. ### QMoE Operator Implementation Refactor - Refactored the QMoE operator to clarify separation between quantized and FP32 implementations, and restructured internal methods for better maintainability. Added template parameterization for data types and improved handling of expert weights and biases. [[1]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5R13-R35) [[2]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L38-R55) [[3]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L58-L59) ### Shape Checking and Layout - Removed legacy shape/layout support in QMoE input validation, enforcing only the new memory layout for expert weights and improving consistency and forward compatibility. ### Test and Documentation Updates - Updated unit tests for QMoE to use correct zero-point values for quantized weights (e.g., 0x88 for int4, 128 for int8), ensuring that test cases accurately reflect expected zero-output behavior for zero weights. Also clarified comments and expected outputs for SwiGLU and quantized scenarios. [[1]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1340-R1349) [[2]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1379-R1380) [[3]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1404-R1413) [[4]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1525-R1538) These changes collectively improve the flexibility, correctness, and maintainability of the QMoE operator in ONNX Runtime. Unit test result ``` sRunning test: batch_size=1, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000372 .Running test: batch_size=1, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000392 .Running test: batch_size=1, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=1, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=4, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000609 .Running test: batch_size=4, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000702 . ---------------------------------------------------------------------- Ran 9 tests in 46.754s OK (skipped=1) ``` --------- Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Author
Parents
Loading