Add DiT attention fusion for F5-TTS and diffusion transformer models (#27999)
## Summary
- Add `FusionMultiHeadAttentionDiT` to recognize DiT-style attention
patterns (F5-TTS, etc.) and fuse them into `MultiHeadAttention`,
enabling Flash Attention dispatch.
- Register the new fusion as a second pass in
`MmditOnnxModel.fuse_multi_head_attention()`, alongside the existing
MMDit fusion for SD3/Flux.
- Add test model generator and three test cases covering FP32, FP16
cast, and custom-scale variants.
## Motivation
Fixes #27983
DiT models like F5-TTS use an attention pattern where Q, K, V are
pre-computed (e.g., after RoPE) in BNSH format, K is pre-transposed to
BNHS, and a custom scalar scale (e.g., 100.0) is applied via `Mul`
before `Softmax`. Optional `Cast` nodes (FP16↔FP32) may wrap `Softmax`
for mixed-precision inference.
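For reference, here is a minimal numerical sketch of what this subgraph computes; the shapes, default scale, and helper names are illustrative, not taken from the PR:
```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dit_attention_reference(q, k_t, v, scale=100.0):
    # q, v: (B, N, S, H) in BNSH; k_t: (B, N, H, S), i.e. K already transposed to BNHS
    logits = np.matmul(q, k_t)                 # MatMul(Q, K^T)
    logits = logits.astype(np.float32)         # optional Cast FP16 -> FP32
    probs = softmax(logits * scale, axis=-1)   # Mul(scale) -> Softmax
    probs = probs.astype(q.dtype)              # optional Cast FP32 -> FP16
    out = np.matmul(probs, v)                  # MatMul(attn, V), still BNSH
    out = out.transpose(0, 2, 1, 3)            # Transpose(0,2,1,3) -> BSNH
    return out.reshape(out.shape[0], out.shape[1], -1)  # Reshape -> (B, S, N*H)
```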
The existing MMDit fusion (for SD3/Flux) expects a specific
`Mul→Sqrt→Div→Sqrt→Cast→Slice→Shape` scaling path and does not match the
simpler `Mul(scalar_constant)` pattern, so the attention is never fused
and Flash Attention is never dispatched. This causes ~44 extra Cast ops
per inference and ~200ms overhead per forward pass.
## Changes
**New files:**
- `onnxruntime/python/tools/transformers/fusion_mha_dit.py` — Core
fusion class that matches the pattern:
```
MatMul(Q, K^T) → [Cast FP16→FP32] → Mul(scale) → Softmax → [Cast FP32→FP16]
  → MatMul(attn, V) → Transpose(0,2,1,3) → Reshape → output
```
and replaces it with a single `MultiHeadAttention` op carrying a `scale`
attribute (a sketch of the replacement node follows this list).
- `onnxruntime/test/python/transformers/dit_model_generator.py` —
Synthetic ONNX graph generators for testing.
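The replacement node emitted by the fusion looks roughly like the sketch below. The input/output names, head count, and layout are illustrative assumptions; the real fusion reuses the tensors that fed the original MatMul/Softmax subgraph and derives `num_heads` and `scale` from the graph.
```python
from onnx import helper

# Illustrative only: names and values are placeholders, not the fusion's actual wiring.
mha_node = helper.make_node(
    "MultiHeadAttention",             # com.microsoft contrib op -> Flash Attention eligible
    inputs=["query", "key", "value"],
    outputs=["attn_out"],
    name="MultiHeadAttention_DiT_0",
    domain="com.microsoft",
    num_heads=16,                     # detected from the upstream Reshape shapes
    scale=100.0,                      # taken from the scalar Mul constant
)
```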
**Modified files:**
- `onnxruntime/python/tools/transformers/onnx_model_mmdit.py` — Register
`FusionMultiHeadAttentionDiT` as a second fusion pass after the existing
MMDit fusion (sketched after this list).
- `onnxruntime/test/python/transformers/test_attention_fusion.py` —
Three new test cases:
- `test_dit_attention_fusion` — FP32 with K pre-transpose, scale=100.0
- `test_dit_attention_fusion_with_fp16_casts` — FP16 Cast nodes around
Softmax
- `test_dit_attention_fusion_custom_scale` — Standard 1/√d_k scale
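A rough sketch of how the second pass might be wired into `MmditOnnxModel.fuse_multi_head_attention()`; only `FusionMultiHeadAttentionDiT` is introduced by this PR, so the name of the existing MMDit fusion class below is an assumption:
```python
def fuse_multi_head_attention(self):
    # First pass: existing MMDit fusion for SD3/Flux (class name assumed for illustration).
    FusionMultiHeadAttentionMMDit(self).apply()
    # Second pass: the new DiT fusion picks up F5-TTS-style patterns left unfused above.
    FusionMultiHeadAttentionDiT(self).apply()
```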
## Test Plan
- All three new DiT fusion tests pass (the checks are sketched in code after this list), verifying:
- Exactly 1 `MultiHeadAttention` node is produced
- `num_heads` attribute is correctly detected from upstream Reshape
shapes
- `scale` attribute matches the original scalar constant
- No `Softmax` nodes remain after fusion
- Existing attention fusion tests remain unaffected
- `ruff check`, `ruff format`, and `lintrunner -a` pass cleanly
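For illustration, the core assertions of one test case could look like this; the `model_type="mmdit"` option, the import path, and the helper name are assumptions based on the existing transformers optimizer tooling, not code from this PR:
```python
from onnxruntime.transformers.optimizer import optimize_model

def check_dit_fusion(model_path: str, expected_scale: float) -> None:
    optimized = optimize_model(model_path, model_type="mmdit")
    ops = [node.op_type for node in optimized.model.graph.node]
    assert ops.count("MultiHeadAttention") == 1   # exactly one fused node
    assert "Softmax" not in ops                   # original pattern fully consumed
    mha = next(n for n in optimized.model.graph.node if n.op_type == "MultiHeadAttention")
    attrs = {a.name: a for a in mha.attribute}
    assert attrs["scale"].f == expected_scale     # scalar from the original Mul
```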