onnxruntime
227c3d5e - Add RotaryEmbedding fusion for Qwen3 on-the-fly RoPE patterns (#27590)

Commit
1 day ago
### Description

Extend `FusionRotaryEmbeddings` to handle Qwen3's on-the-fly rotary position embedding computation, where cos/sin values are computed from `inv_freq` at runtime instead of being looked up from a pre-computed cache. This is a follow-up to #27556 (Qwen3 basic model type support).

Depends on #27556. Part of #25083.

### Motivation and Context

Qwen3 models (ranked 4th on MTEB) compute RoPE differently from existing supported models (Phi, LLaMA, etc.). Instead of pre-computing cos/sin caches and looking them up via `Gather(cache, position_ids)`, Qwen3 computes them on-the-fly:

```python
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # MatMul, Transpose
emb = torch.cat((freqs, freqs), dim=-1)  # Concat
cos = emb.cos() * attention_scaling  # Cos, Mul
sin = emb.sin() * attention_scaling  # Sin, Mul
```

Additionally, TorchScript exports of Qwen3 insert `Cast` nodes in the `rotate_half` pattern (from `torch.floor_divide` tracing), which the existing path patterns don't account for.

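To illustrate why the on-the-fly pattern can be replaced by a cache lookup, here is a minimal NumPy sketch (not the optimizer's code). The sizes, `theta`, and the helper names `on_the_fly` and `build_cache` are hypothetical, and folding `attention_scaling` into the cache and using a half-width cache layout are assumptions made for this example.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
head_dim, max_positions, theta = 8, 16, 10000.0
attention_scaling = 1.0

# Standard RoPE inverse frequencies: head_dim // 2 entries.
inv_freq = 1.0 / theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim)


def on_the_fly(position_ids):
    """Runtime computation: MatMul -> Transpose -> Concat -> Cos/Sin -> Mul."""
    batch = position_ids.shape[0]
    inv_freq_expanded = np.broadcast_to(inv_freq[None, :, None], (batch, inv_freq.size, 1))
    position_ids_expanded = position_ids[:, None, :].astype(np.float64)
    # (batch, d/2, 1) @ (batch, 1, seq) -> (batch, d/2, seq) -> (batch, seq, d/2)
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
    emb = np.concatenate((freqs, freqs), axis=-1)
    return np.cos(emb) * attention_scaling, np.sin(emb) * attention_scaling


def build_cache():
    """Precompute one cos/sin row per position (assumed half-width layout)."""
    freqs = np.outer(np.arange(max_positions), inv_freq)
    return np.cos(freqs) * attention_scaling, np.sin(freqs) * attention_scaling


position_ids = np.array([[0, 1, 2, 5], [3, 4, 15, 7]])
cos, sin = on_the_fly(position_ids)
cos_cache, sin_cache = build_cache()

# Gathering cache rows by position reproduces each half of the runtime result.
half = head_dim // 2
assert np.allclose(cos[..., :half], cos_cache[position_ids])
assert np.allclose(sin[..., half:], sin_cache[position_ids])
```

Because both halves of `emb` are identical copies of `freqs`, a cache holding only `head_dim // 2` columns per position carries the same information as the runtime tensors.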
### Changes

**`fusion_rotary_attention.py`:**
- Add Cast-tolerant `rotate_half` path patterns (`rotate_half_x2_path_2_3`, `_2_4`, `rotate_half_x1_path_2_3`, `_2_4`) that allow 1-2 Cast nodes between Unsqueeze and Div in the dynamic Slice index computation
- Add `sin_path_5` / `cos_path_5` patterns matching the on-the-fly computation: `MatMul → Transpose → Concat → Cos/Sin → Mul(scaling) → Unsqueeze → Mul`, with optional Cast variant (the optimizer's earlier Cast fusion pass may remove the Cast)
- Add `create_cos_sin_cache_from_on_the_fly_rope()` helper that extracts `inv_freq` weights, computes cos/sin caches as model initializers, and traces `position_ids` from the graph
- Handle per-layer vs shared node removal correctly (only remove per-layer Unsqueeze/outer Mul nodes; shared MatMul/Cos/Sin nodes are pruned automatically by the optimizer)

**`qwen3_model_generator.py`:**
- Add `include_rope=True` parameter to `create_qwen3_decoder_layer()`
- Generate full on-the-fly RoPE computation graph: `inv_freq` initializer, `position_ids` input, MatMul/Transpose/Concat/Cos/Sin/Mul nodes, and `rotate_half` pattern with dynamic Slice indices (including Cast nodes from floor division)
- Apply RoPE to both Q and K paths

**`test_attention_fusion.py`:**
- Add `test_qwen3_rotary_embedding_fusion` verifying 2 RotaryEmbedding nodes are fused along with 3 SimplifiedLayerNormalization and 1 SkipSimplifiedLayerNormalization

### Verification

- **Unit tests**: All 15 `test_attention_fusion.py` tests pass (14 existing + 1 new)
- **Real model**: Verified on Qwen3-Embedding-0.6B (28 layers): 56 RotaryEmbedding nodes fused (28 layers × 2 per layer for Q and K), reducing total node count from 7416 → 4661 (37% reduction)
- **No regressions**: All changes are additive alternative path patterns; existing models that use dynamic Slice indices or cache-based RoPE never hit the new paths
- **Lint**: `lintrunner -a` clean on all modified files
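
The Cast-tolerant path patterns listed under Changes come down to trying several alternative op-type sequences and accepting the first that matches. The toy below sketches that idea on a dictionary-based graph. It is not onnxruntime's matcher: `match_chain`, `match_any`, the pattern names, and the simplified op sequences are all hypothetical.

```python
# Toy graph: node name -> (op_type, name of its single parent or None).
# Real ONNX graphs have multiple inputs per node; this is deliberately simplified.


def match_chain(graph, start, op_types):
    """Walk up parent links from `start`; return matched node names or None."""
    matched = []
    current = graph[start][1]
    for op_type in op_types:
        if current is None or graph[current][0] != op_type:
            return None
        matched.append(current)
        current = graph[current][1]
    return matched


def match_any(graph, start, patterns):
    """Try alternative patterns in insertion order; return the first hit."""
    for name, op_types in patterns.items():
        nodes = match_chain(graph, start, op_types)
        if nodes is not None:
            return name, nodes
    return None


# Hypothetical alternatives: zero, one, or two Cast nodes between Unsqueeze and Div.
patterns = {
    "no_cast": ["Unsqueeze", "Div", "Gather", "Shape"],
    "one_cast": ["Unsqueeze", "Cast", "Div", "Gather", "Shape"],
    "two_casts": ["Unsqueeze", "Cast", "Cast", "Div", "Gather", "Shape"],
}

# A Slice index computed with one Cast inserted by floor-division tracing.
graph = {
    "shape": ("Shape", None),
    "gather": ("Gather", "shape"),
    "div": ("Div", "gather"),
    "cast": ("Cast", "div"),
    "unsqueeze": ("Unsqueeze", "cast"),
    "slice": ("Slice", "unsqueeze"),
}

result = match_any(graph, "slice", patterns)
```

Since the existing pattern is tried first and the Cast variants only match graphs that contain Cast nodes, adding alternatives this way does not change which graphs the original pattern accepts.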