onnxruntime
1c577b71 - Allow FP16 math in flash attention (#24953)

### Description
This change restores FP16 math in FlashAttention.

### Motivation and Context
Earlier we noticed quality issues with deepseek-r1, attributed to overflow of the QK computation when the math is performed in FP16 precision. https://github.com/microsoft/onnxruntime/pull/24723 addressed this by promoting the math to FP32 to avoid the precision issue. However, the question remained: these models are trained with FP8 precision, so why does inference run into precision issues with FP16 math? Using FP32 math also resulted in a slight performance degradation.

In this follow-up investigation, one issue identified is that we multiply the scale for GQA quite late in the computation. The scale is 0.088 for deepseek-r1. Multiplying the scale upfront seems to prevent the overflow issues. For now, only the prefill shaders are updated to use this approach. Pending feedback on impact across models, the generation shader can also be restored to FP16 math.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
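Below is a minimal NumPy sketch (not the actual WebGPU shader code) of the effect described above: when FP16 partial sums of the QK^T dot product exceed the FP16 maximum (65504), applying the scale late cannot recover the result, whereas multiplying Q by the scale up front keeps the accumulation in range. The head dimension of 128, the resulting scale of ~0.088, and the input values are assumptions for illustration only.

```python
import numpy as np

FP16_MAX = np.finfo(np.float16).max  # 65504

def qk_dot_fp16(q_row, k_row):
    """Dot product with every partial sum rounded to FP16,
    mimicking FP16 accumulation in the attention shader."""
    acc = np.float16(0.0)
    for a, b in zip(q_row, k_row):
        acc = np.float16(acc + np.float16(a) * np.float16(b))
    return acc

# Hypothetical head dimension; 1/sqrt(128) ~ 0.088 matches the deepseek-r1 scale.
head_dim = 128
scale = np.float16(1.0 / np.sqrt(head_dim))

rng = np.random.default_rng(0)
# Large-magnitude activations that are individually representable in FP16.
q = rng.uniform(20.0, 30.0, head_dim).astype(np.float16)
k = rng.uniform(20.0, 30.0, head_dim).astype(np.float16)

# Late scaling: the raw QK dot product overflows FP16 before scale is applied.
late = qk_dot_fp16(q, k)                  # becomes inf once the sum exceeds 65504
print("scale applied after QK^T:", late * scale)

# Early scaling: multiplying Q by the scale up front keeps partial sums in range.
early = qk_dot_fp16(q * scale, k)
print("scale applied to Q up front:", early)
```

Because the dot product is a sum of `head_dim` products, folding the scale into Q shrinks every partial sum by the same factor, which is why the change only affects where the multiplication happens, not the mathematical result.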