onnxruntime
e6c84b80 - [CUDA] Support head_sink in flash attention for GQA (#25432)

Commit
189 days ago
[CUDA] Support head_sink in flash attention for GQA (#25432) ### Description Update Flash Attention to support softmax sink in GQA. Changes: - [x] Update flash attention to support head_sink - [x] Add test_gqa.py to test cuda, and remove test_gqa_cuda.py. Note that the sink is treated as scaled, while the elements in QK GEMMs is not scaled. The sink value does not need scaling or softcap, and it joins softmax with those scaled or soft-capped values. There are two ways to add sink to softmax: * One way is to [patch normalize_softmax_lse](https://github.com/microsoft/onnxruntime/blob/1cf1aa786f6e7f7e6abd6fba8b8aea2e7a43092c/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h#L143-L178) to use sink to update max and sum. Pros is major change in one function; Cons is the logic is a little complex since row_max is unscaled, while row_sum is scaled. * Another way is to change softmax_rescale_o to handle the sink directly in the first block of online softmax by using an unscaled sink value. It is a robust way to keep core algorithm consistent. Cons is need change in multiple places, and it is little hard to work with softcap. This PR use the the first approach for easy integration. Note: Memory efficient attention change will be in separated PR. ### Motivation and Context https://github.com/microsoft/onnxruntime/pull/25269
Author
Parents
Loading