onnxruntime
eb706ed3 - Fix CUDA ONNX Attention: min_bias_align crash on SM<80 and MEA NaN for fully-masked batches (#27831)

### Summary

Fixes three issues in the CUDA ONNX Attention operator and improves spec compliance:

1. `min_bias_align` crash on SM<80: The alignment check for the Memory Efficient Attention (MEA) bias used `4 * sizeof(T)` (a byte count), but the check is against element counts. Fixed to 4 elements, matching CUTLASS `kMinimumAlignment`. The byte-based value previously blocked valid MEA dispatch on SM<80 (see the first sketch after this message).
2. MEA NaN for fully-masked batches: When `nonpad_kv_seqlen == 0`, CUTLASS MEA computes `1 / s_prime` with `s_prime == 0`, producing NaN. Added a `ZeroOutputForFullyMaskedBatches` kernel (MEA path only) that zeroes the output for these batches. It uses an `int64_t` element count to prevent overflow at large context lengths (see the second sketch below).
3. Flash rejects `attn_mask` for spec compliance: Flash Attention's paged KV cache produces a spec-divergent `present_key`/`present_value` layout when used with `attn_mask` + `past_key`. Flash now requires `attn_mask == nullptr`; cases with a bool mask + `past_key` fall back to the unfused runner, which handles them spec-correctly (see the third sketch below). Removed ~137 lines of dead code (`ConvertMaskToSeqlensKernel`, `LaunchConvertMaskToFlashSeqlensK`) that were no longer needed after this change.

### Known limitation

- GQA + bool `attn_mask` + `past_key` currently has no runner (Flash rejects the mask, the unfused runner does not support GQA, and MEA is blocked by `past_key`). Tracked via TODO; PR #27851 (MEA with `past_key` support) will close this gap.

### Related

- Issue #27885: Flash Attention bool `attn_mask` semantic divergence (root cause documented)
- PR #27851: MEA with `past_key` support (will close the GQA gap)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
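For illustration, a minimal sketch of the element-vs-byte fix in item 1. The function name `BiasLastDimIsAlignedForMea` and the `bias_last_dim` parameter are invented for this sketch; the real onnxruntime check sits inside the attention dispatch code and differs in detail:

```cpp
#include <cstdint>

// CUTLASS kMinimumAlignment for the MEA bias is expressed in ELEMENTS (4),
// independent of sizeof(T).
template <typename T>
bool BiasLastDimIsAlignedForMea(int64_t bias_last_dim) {
  // Buggy form: constexpr int64_t min_bias_align = 4 * sizeof(T);
  // That value is in bytes, so for fp16 it demanded 8-element alignment and
  // wrongly rejected valid MEA dispatch on SM<80.
  constexpr int64_t min_bias_align = 4;  // fixed: 4 elements for every T
  return bias_last_dim % min_bias_align == 0;
}
```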
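A hedged CUDA sketch of the fully-masked-batch cleanup in item 2. The kernel name matches the commit's `ZeroOutputForFullyMaskedBatches`, but the signature, memory layout, and launch shape here are assumptions (and `T` is assumed constructible from `0.0f`):

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Zero the attention output for any batch whose key/value sequence is fully
// masked (nonpad_kv_seqlen[b] == 0), so the NaN that MEA writes via
// 1/s_prime (with s_prime == 0) never reaches the caller.
// Launch sketch: <<<dim3(256, batch_size), 256, 0, stream>>>.
template <typename T>
__global__ void ZeroOutputForFullyMaskedBatchesSketch(
    T* output,                     // [batch, elements_per_batch], assumed contiguous
    const int* nonpad_kv_seqlen,   // [batch]
    int64_t elements_per_batch) {  // int64_t: avoids overflow at large context lengths
  const int batch = blockIdx.y;
  if (nonpad_kv_seqlen[batch] != 0) {
    return;  // at least one valid KV position: MEA output is already correct
  }
  T* batch_out = output + static_cast<int64_t>(batch) * elements_per_batch;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < elements_per_batch; i += stride) {
    batch_out[i] = T(0.0f);
  }
}
```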
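Finally, a sketch of the runner-selection consequence of item 3. Every name here (`AttentionRunner`, `AttentionParams`, `SelectRunner`) is invented, and the real onnxruntime dispatch weighs many more constraints; this only shows the ordering the commit describes:

```cpp
enum class AttentionRunner { kFlash, kMea, kUnfused };

struct AttentionParams {
  const void* attn_mask;  // bool mask tensor, or nullptr
  const void* past_key;   // past KV cache, or nullptr
  int sm;                 // compute capability, e.g. 80
  bool bias_aligned;      // the 4-element alignment check sketched above
};

AttentionRunner SelectRunner(const AttentionParams& p) {
  // Flash's paged KV cache lays out present_key/present_value in a
  // spec-divergent way when combined with attn_mask + past_key, so Flash
  // now requires attn_mask == nullptr.
  if (p.attn_mask == nullptr && p.sm >= 80) {
    return AttentionRunner::kFlash;
  }
  // MEA handles masks but (as of this commit) is blocked by past_key;
  // PR #27851 is expected to lift that restriction.
  if (p.past_key == nullptr && p.bias_aligned) {
    return AttentionRunner::kMea;
  }
  // The unfused runner handles bool mask + past_key spec-correctly
  // (but not GQA, hence the known limitation above).
  return AttentionRunner::kUnfused;
}
```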