onnxruntime
5d24c954 - [CUDA] Update Flash Attention Implementation and APIs (#26937)

Commit
8 days ago
[CUDA] Update Flash Attention Implementation and APIs (#26937) ## Summary This PR updates the Flash Attention implementation in ONNX Runtime, syncing with newer kernel sources in https://github.com/Dao-AILab/flash-attention, and extending the internal API to support additional features required for advanced caching scenarios. It also aligns specific kernels with the official implementation. ## Changes - **Flash Attention Kernels**: Updated/Added Flash Attention forward kernels and headers in `onnxruntime/contrib_ops/cuda/bert/flash_attention/`. - **API Extension**: Updated `mha_fwd` and `mha_fwd_kvcache` in `flash_api.h` and `flash_api.cc` to accept two new optional parameters: - `cache_batch_idx`: Indices to index into the KV cache (support for non-contiguous batch indices). - `leftpad_k`: Support for left-padding in the key sequence. - **Alignment & Fixes**: - **Cleanup**: Removed redundant `kInfinity` definition in `flash_fwd_kernel.h`. - **Includes**: Added missing `<core/providers/cuda/shared_inc/cuda_call.h>` in `flash_fwd_launch_template.h`. - **Integration**: Updated `group_query_attention_impl.cu` to align with the new `mha_fwd_kvcache` signature. - **Build Configuration**: Adjusted `onnxruntime_providers_cpu.cmake` to update the exclusion list for Flash Attention kernels in quick build mode. ## Implementation Details - The `run_mha_fwd` helper now checks if `cache_batch_idx` is provided alongside `k_new` to determine if the split kernel should be forced. - New parameters are propagated through the call stack to the underlying Flash Attention kernels.
Author
Parents
Loading