onnxruntime
ece097dc - Add DiT attention fusion for F5-TTS and diffusion transformer models (#27999)

Add DiT attention fusion for F5-TTS and diffusion transformer models (#27999)

## Summary

- Add `FusionMultiHeadAttentionDiT` to recognize DiT-style attention patterns (F5-TTS, etc.) and fuse them into `MultiHeadAttention`, enabling Flash Attention dispatch.
- Register the new fusion as a second pass in `MmditOnnxModel.fuse_multi_head_attention()`, alongside the existing MMDit fusion for SD3/Flux.
- Add a test model generator and three test cases covering the FP32, FP16-cast, and custom-scale variants.

## Motivation

Fixes #27983

DiT models like F5-TTS use an attention pattern in which Q, K, and V are pre-computed (e.g., after RoPE) in BNSH format, K is pre-transposed to BNHS, and a custom scalar scale (e.g., 100.0) is applied via `Mul` before `Softmax`. Optional `Cast` nodes (FP16↔FP32) may wrap `Softmax` for mixed-precision inference.

The existing MMDit fusion (for SD3/Flux) expects a specific `Mul→Sqrt→Div→Sqrt→Cast→Slice→Shape` scaling path and does not match the simpler `Mul(scalar_constant)` pattern, so the attention is never fused and Flash Attention is never dispatched. This leaves roughly 44 extra Cast ops per inference and about 200 ms of overhead per forward pass.

## Changes

**New files:**

- `onnxruntime/python/tools/transformers/fusion_mha_dit.py` — Core fusion class that matches the pattern

  ```
  MatMul(Q, K^T) → [Cast FP16→FP32] → Mul(scale) → Softmax → [Cast FP32→FP16]
      → MatMul(attn, V) → Transpose(0,2,1,3) → Reshape → output
  ```

  and replaces it with a single `MultiHeadAttention` op (with a `scale` attribute). A simplified sketch of the matching logic appears under "Illustrative sketches" below.
- `onnxruntime/test/python/transformers/dit_model_generator.py` — Synthetic ONNX graph generators for testing (see the generator sketch below).

**Modified files:**

- `onnxruntime/python/tools/transformers/onnx_model_mmdit.py` — Register `FusionMultiHeadAttentionDiT` as a second fusion pass after the existing MMDit fusion (see the registration sketch below).
- `onnxruntime/test/python/transformers/test_attention_fusion.py` — Three new test cases:
  - `test_dit_attention_fusion` — FP32 with K pre-transposed, scale=100.0
  - `test_dit_attention_fusion_with_fp16_casts` — FP16 Cast nodes around Softmax
  - `test_dit_attention_fusion_custom_scale` — standard 1/√d_k scale

## Test Plan

- All three new DiT fusion tests pass, verifying that:
  - exactly one `MultiHeadAttention` node is produced;
  - the `num_heads` attribute is correctly detected from the upstream Reshape shapes;
  - the `scale` attribute matches the original scalar constant;
  - no `Softmax` nodes remain after fusion.
- Existing attention fusion tests remain unaffected.
- `ruff check`, `ruff format`, and `lintrunner -a` pass clean.
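## Illustrative sketches

None of the code below is part of the diff; these are simplified sketches to make the description above concrete.

First, a minimal sketch of the matching logic, assuming the conventions of the `Fusion` base class used by the other fusions in `onnxruntime/python/tools/transformers`. It is not the actual `fusion_mha_dit.py` implementation: the real class must also extract V, handle the pre-transposed K, detect `num_heads` from the trailing Reshape, and remove the replaced nodes, all elided here, and the placeholder input/output names are hypothetical.

```python
# Simplified sketch only -- not the actual fusion_mha_dit.py implementation.
from onnx import helper

from fusion_base import Fusion
from onnx_model import OnnxModel


class FusionMhaDitSketch(Fusion):
    """Match MatMul(Q,K^T) -> [Cast] -> Mul(scale) -> Softmax -> [Cast] -> MatMul(.,V)
    and emit a com.microsoft MultiHeadAttention node carrying a scale attribute."""

    def __init__(self, model: OnnxModel):
        # Anchor the search on Softmax, the one node every pattern variant contains.
        super().__init__(model, "MultiHeadAttention", "Softmax")

    def fuse(self, softmax, input_name_to_nodes, output_name_to_node):
        # Upstream walk: Softmax <- Mul(scale) <- [Cast FP16->FP32] <- MatMul(Q, K^T).
        path = self.model.match_parent_path(softmax, ["Mul", "MatMul"], [0, 0])
        if path is None:
            path = self.model.match_parent_path(softmax, ["Mul", "Cast", "MatMul"], [0, 0, 0])
        if path is None:
            return
        mul, qk_matmul = path[0], path[-1]

        # This sketch assumes the scalar scale is the second Mul input; the real
        # fusion would check both inputs.
        scale = self.model.get_constant_value(mul.input[1])
        if scale is None:
            return  # only a scalar-constant scale matches this pattern

        # Elided: downstream walk to MatMul(attn, V) and Transpose/Reshape,
        # num_heads detection, and removal of the replaced nodes.
        fused = helper.make_node(
            "MultiHeadAttention",
            inputs=[qk_matmul.input[0], qk_matmul.input[1], "v_placeholder"],
            outputs=["mha_out_placeholder"],
            name=self.model.create_node_name("MultiHeadAttention"),
            domain="com.microsoft",
        )
        fused.attribute.extend([helper.make_attribute("scale", float(scale.item()))])
        self.nodes_to_add.append(fused)
        self.node_name_to_graph_name[fused.name] = self.this_graph_name
```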
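The registration in `onnx_model_mmdit.py` plausibly takes the following shape; the existing MMDit fusion's class and module names here are guesses from the PR text, not verified source.

```python
# Assumed shape of the change in onnx_model_mmdit.py (names are hypothetical).
from fusion_mha_dit import FusionMultiHeadAttentionDiT  # new in this PR
from fusion_mha_mmdit import FusionMultiHeadAttentionMMDit  # hypothetical name
from onnx_model import OnnxModel


class MmditOnnxModel(OnnxModel):
    def fuse_multi_head_attention(self):
        # First pass: existing SD3/Flux pattern (Mul->Sqrt->Div->... scaling path).
        FusionMultiHeadAttentionMMDit(self).apply()

        # Second pass: the new DiT pattern (scalar Mul scale, optional Casts),
        # which picks up attention subgraphs the first pass did not match.
        FusionMultiHeadAttentionDiT(self).apply()
```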
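To make the fused pattern concrete, here is a self-contained sketch of the kind of synthetic graph the test generator builds: Q and V in BNSH, K pre-transposed to BNHS, and a scalar `Mul(100.0)` before `Softmax`. Shapes, names, and the opset are illustrative, not the actual `dit_model_generator.py`.

```python
import onnx
from onnx import TensorProto, helper

B, N, S, H = 1, 8, 64, 32  # batch, num_heads, sequence length, head size


def make_dit_attention_model(scale: float = 100.0) -> onnx.ModelProto:
    q = helper.make_tensor_value_info("q", TensorProto.FLOAT, [B, N, S, H])
    k_t = helper.make_tensor_value_info("k_t", TensorProto.FLOAT, [B, N, H, S])  # pre-transposed K
    v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [B, N, S, H])
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [B, S, N * H])

    scale_const = helper.make_tensor("scale", TensorProto.FLOAT, [], [scale])
    out_shape = helper.make_tensor("out_shape", TensorProto.INT64, [3], [B, S, N * H])

    nodes = [
        helper.make_node("MatMul", ["q", "k_t"], ["qk"]),                      # (B, N, S, S)
        helper.make_node("Mul", ["qk", "scale"], ["qk_scaled"]),               # scalar scale
        helper.make_node("Softmax", ["qk_scaled"], ["probs"], axis=-1),
        helper.make_node("MatMul", ["probs", "v"], ["ctx"]),                   # (B, N, S, H)
        helper.make_node("Transpose", ["ctx"], ["ctx_t"], perm=[0, 2, 1, 3]),  # BNSH -> BSNH
        helper.make_node("Reshape", ["ctx_t", "out_shape"], ["out"]),          # (B, S, N*H)
    ]
    graph = helper.make_graph(nodes, "dit_attention", [q, k_t, v], [out],
                              initializer=[scale_const, out_shape])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
```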
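Finally, a usage sketch of running the optimizer over that graph and checking the post-fusion invariants the tests assert. This continues from the generator sketch above and assumes `model_type="mmdit"` routes to `MmditOnnxModel`; the exact harness calls in `test_attention_fusion.py` may differ.

```python
import onnx
from onnxruntime.transformers import optimizer

model = make_dit_attention_model(scale=100.0)
onnx.save(model, "dit_attention.onnx")

# "mmdit" selects MmditOnnxModel, whose fuse_multi_head_attention() now runs
# the DiT fusion as a second pass.
opt = optimizer.optimize_model(
    "dit_attention.onnx", model_type="mmdit", num_heads=N, hidden_size=N * H
)

ops = [node.op_type for node in opt.model.graph.node]
assert ops.count("MultiHeadAttention") == 1  # exactly one fused MHA node
assert "Softmax" not in ops                  # pattern fully consumed by the fusion
```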