onnxruntime
cee825d3 - [MLAS] Fix rotary interleaved NEON kernel (#26390)

Commit
105 days ago
[MLAS] Fix rotary interleaved NEON kernel (#26390) The logic of interleaved NEON kernel is not correct from code review: 1. **Test Code Logic:** The test code `test_rope.h` allocates the `sin` and `cos` tables based on the `interleaved` flag: ```c++ size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; std::vector<float> sin_data(table_len); std::vector<float> cos_data(table_len); ``` For the `interleaved = true` case, the test creates `sin` and `cos` tables of length `rotary_emb_dim / 2`. 2. **AVX2 (fp32) Kernel Logic (`interleaved = true`):** This kernel loads the `sin`/`cos` data using an index of `i / 2`: ```c++ float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2); float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); ``` This logic expects a `sin`/`cos` table of length `rotary_emb_dim / 2`. **Conclusion: The AVX2 (fp32) kernel is consistent with the test code.** 3. **NEON (fp16) Kernel Logic (`interleaved = true`):** This kernel loads the `sin`/`cos` data using an index of `i`: ```c++ // Enters loop with sin_val = MlasLoadFloat16x8(sin + i); //... // Inside loop, for next iteration: sin_val = MlasLoadFloat16x8(sin + i + 16); ``` This logic expects a `sin`/`cos` table of length `rotary_emb_dim`. **Conclusion: The NEON (fp16) kernel is NOT consistent with the test code.** ### Regression Test ``` cmake --build build/Linux/Release --config Release --target onnxruntime_mlas_test && ./build/Linux/Release/onnxruntime_mlas_test --gtest_filter=NeonFp16RoPE* ``` Before applying the fix, the test failed: ``` [ FAILED ] NeonFp16RoPE.ShortExecute (13 ms) onnxruntime/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp:66: Failure Value of: CloseEnough(output_impl[i].ToFloat(), output_ref[i].ToFloat()) Actual: false Expected: true Expected bits: 19491 (16.546875) Actual bits: 56596 (-325) @[16], rotary_emb_dim=24, interleaved=true ``` After applying the fix, test passed. ### Summary The `RopeKernel_Avx2_fp32_Impl<true>` kernel correctly aligns with the test code (and the fallback implementation) by expecting a `sin`/`cos` table of length `rotary_emb_dim / 2`. The `RopeKernel_Fp16_Impl<true>` (NEON) kernel incorrectly expects a table of length `rotary_emb_dim`. When run against the provided test, the NEON kernel will read past the end of the `sin_data` and `cos_data` vectors. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Author
Parents
Loading