onnxruntime
99e0119b - Optimize ONNX Attention KV cache with ConcatNewToPast and add release-build kernel safety (#27613)

Optimize ONNX Attention KV cache with ConcatNewToPast and add release-build kernel safety (#27613)

This pull request introduces several optimizations and safety improvements to the CUDA kernels used in attention, rotary embedding, and tensor scatter operations for ONNX Runtime's LLM support. The main goals are reducing decode overhead, improving memory safety, and ensuring correct handling of edge cases in mask and index validation. The most important changes are grouped below by theme.

### Flash Attention & KV Cache Optimization

* Replaced the previous pattern of zero-filling plus strided copy for KV cache updates with a single fused kernel (`LaunchConcatNewToPastKV`) in `attention.cc`, eliminating redundant memory writes and reducing decode overhead. [[1]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL313-R350) [[2]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL396-R436)
* Updated documentation and comments to explain the fused-kernel approach and its performance benefits, and to clarify how sequence lengths are handled for cache and mask conversion. [[1]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL190-R192) [[2]](diffhunk://#diff-0701e4cc6d4951894ae1a60f35c1e6c0f69ba7595f896a23c8f5ed7265eab4ffL417-R446)

### Mask Validation & Handling

* Improved mask validation in `attention_mask_impl.cu` and `attention_mask_impl.h` by clarifying that `CUDA_KERNEL_ASSERT` is only active in debug builds. In release builds, non-contiguous masks produce safe output by counting only leading True values, preserving memory safety and correctness even when the mask is invalid. [[1]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49L11-R22) [[2]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49R100) [[3]](diffhunk://#diff-8aa9a15a92d7dc138346dce5de055911895d940ba2183b4ba45bd95ac0e5bfc9L32-R35)

### Rotary Embedding Improvements

* Switched rotary embedding kernel dispatch to use `OrtToCudaType` for BFloat16, enabling native hardware arithmetic (`__nv_bfloat16`) on SM80+ GPUs for improved performance and correctness. [[1]](diffhunk://#diff-411fdb2010086b3a0ad9b048bb0d0fd7721a0e8d33d9ad396d709254973448c2R5) [[2]](diffhunk://#diff-411fdb2010086b3a0ad9b048bb0d0fd7721a0e8d33d9ad396d709254973448c2L67-R71) [[3]](diffhunk://#diff-b0846d38debfc56c4c9fbb52ae7a201323ec1eab36853cf3627838fce4bb98feR13) [[4]](diffhunk://#diff-b0846d38debfc56c4c9fbb52ae7a201323ec1eab36853cf3627838fce4bb98feR168-R176)
* Added an explicit kernel instantiation for `__nv_bfloat16` in the rotary embedding implementation, ensuring proper support for native CUDA types.

### TensorScatter Safety Enhancements

* Hardened validation and memory safety for `write_indices` in tensor scatter operations by clamping invalid indices inside the kernel and documenting the behavior in comments. This prevents out-of-bounds writes while preserving CUDA graph compatibility. [[1]](diffhunk://#diff-d69233ff3987fe3093132a31710b6b64cc0a32140e2a5a415a2f1f0907bd22d2L75-R80) [[2]](diffhunk://#diff-1694a04b8ba9963cc06d651ec6a3be8aa9cb2bcb73c2438dc251ca8cdcb2eb41R32-R40)

Together, these changes improve the performance, robustness, and safety of CUDA-based LLM operations in ONNX Runtime.

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
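The fused KV-cache update described above can be sketched on the CPU roughly as follows. This is a simplified illustration, not the actual `LaunchConcatNewToPastKV` kernel: the function name, the assumed BNSH layout (batch, num_heads, sequence, head_size), and the parameter names are stand-ins. The point it shows is that each output element of `present` is written exactly once from either `past` or `new_kv`, instead of a zero-fill pass followed by strided copies.

```cpp
#include <cstddef>
#include <vector>

// CPU sketch of a fused "concat new to past" KV-cache update (illustrative
// names and layout, not the ONNX Runtime kernel). Layout is BNSH.
void ConcatNewToPast(const std::vector<float>& past,    // [B, N, S_past, H]
                     const std::vector<float>& new_kv,  // [B, N, S_new, H]
                     std::vector<float>& present,       // [B, N, S_past+S_new, H]
                     std::size_t batch, std::size_t num_heads,
                     std::size_t past_len, std::size_t new_len,
                     std::size_t head_size) {
  const std::size_t total_len = past_len + new_len;
  for (std::size_t b = 0; b < batch; ++b) {
    for (std::size_t n = 0; n < num_heads; ++n) {
      for (std::size_t s = 0; s < total_len; ++s) {
        // Each output row is sourced from past or new, so no zero-fill pass
        // and no second strided copy are needed.
        const bool from_past = s < past_len;
        const std::vector<float>& src = from_past ? past : new_kv;
        const std::size_t src_s = from_past ? s : s - past_len;
        const std::size_t src_len = from_past ? past_len : new_len;
        for (std::size_t h = 0; h < head_size; ++h) {
          present[((b * num_heads + n) * total_len + s) * head_size + h] =
              src[((b * num_heads + n) * src_len + src_s) * head_size + h];
        }
      }
    }
  }
}
```

In the real kernel each GPU thread handles a slice of this index space; the single-pass structure is what removes the redundant writes.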
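The release-build mask behavior described above amounts to deriving the sequence length by counting only the leading True values of a right-padded boolean mask. A minimal CPU sketch (the function name is illustrative, not the actual kernel helper):

```cpp
#include <vector>

// Sketch of release-build mask handling: since CUDA_KERNEL_ASSERT is
// debug-only, a non-contiguous mask is not rejected; instead the length is
// the count of leading True values, which is always bounded by the mask
// size and therefore memory-safe.
int CountLeadingTrue(const std::vector<bool>& mask) {
  int count = 0;
  for (bool m : mask) {
    if (!m) break;  // stop at the first False; any later Trues are ignored
    ++count;
  }
  return count;
}
```

For a valid mask like `{true, true, true, false}` this yields the true length 3; for an invalid, non-contiguous mask like `{true, false, true, true}` it yields 1, a safe in-bounds length rather than undefined behavior.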
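The in-kernel clamping of invalid `write_indices` can be sketched on the CPU as below. The function name, row-based layout, and the exact clamping policy (clamp into `[0, num_rows)`) are assumptions for illustration; the idea shown is that an out-of-range index is forced into bounds rather than aborting the kernel, so every write stays inside the destination buffer and CUDA graph capture is not broken by a host-side check.

```cpp
#include <cstddef>
#include <vector>

// Sketch of scatter with in-kernel index clamping (illustrative, not the
// ONNX Runtime TensorScatter kernel). Copies row i of src to row
// write_indices[i] of dst, clamping invalid indices into range.
void ScatterRows(const std::vector<float>& src,
                 const std::vector<long>& write_indices,
                 std::vector<float>& dst, long num_rows, long row_size) {
  for (std::size_t i = 0; i < write_indices.size(); ++i) {
    long idx = write_indices[i];
    // Clamp instead of asserting: keeps the write in bounds even for
    // invalid indices, with no device-side abort.
    if (idx < 0) idx = 0;
    if (idx >= num_rows) idx = num_rows - 1;
    for (long j = 0; j < row_size; ++j) {
      dst[idx * row_size + j] = src[i * row_size + j];
    }
  }
}
```

With a destination of 3 rows, an invalid index such as 7 is clamped to the last row instead of writing past the end of `dst`.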