Add RotaryEmbedding fusion for Qwen3 on-the-fly RoPE patterns (#27590)
### Description
Extend `FusionRotaryEmbeddings` to handle Qwen3's on-the-fly rotary
position embedding computation, where cos/sin values are computed from
`inv_freq` at runtime instead of being looked up from a pre-computed
cache.
This is a follow-up to #27556 (Qwen3 basic model type support). Depends
on #27556.
Part of #25083.
### Motivation and Context
Qwen3 models (ranked 4th on MTEB) compute RoPE differently from existing
supported models (Phi, LLaMA, etc.). Instead of pre-computing cos/sin
caches and looking them up via `Gather(cache, position_ids)`, Qwen3
computes them on-the-fly:
```python
freqs = inv_freq_expanded @ position_ids_expanded # MatMul
emb = torch.cat((freqs, freqs), dim=-1) # Concat
cos = emb.cos() * attention_scaling # Cos, Mul
sin = emb.sin() * attention_scaling # Sin, Mul
```
Additionally, TorchScript exports of Qwen3 insert `Cast` nodes in the
`rotate_half` pattern (from `torch.floor_divide` tracing), which the
existing path patterns don't account for.
### Changes
**`fusion_rotary_attention.py`:**
- Add Cast-tolerant `rotate_half` path patterns
(`rotate_half_x2_path_2_3`, `_2_4`, `rotate_half_x1_path_2_3`, `_2_4`)
that allow 1-2 Cast nodes between Unsqueeze and Div in the dynamic Slice
index computation
- Add `sin_path_5` / `cos_path_5` patterns matching the on-the-fly
computation: `MatMul → Transpose → Concat → Cos/Sin → Mul(scaling) →
Unsqueeze → Mul`, with optional Cast variant (the optimizer's earlier
Cast fusion pass may remove the Cast)
- Add `create_cos_sin_cache_from_on_the_fly_rope()` helper that extracts
`inv_freq` weights, computes cos/sin caches as model initializers, and
traces `position_ids` from the graph
- Handle per-layer vs shared node removal correctly (only remove
per-layer Unsqueeze/outer Mul nodes; shared MatMul/Cos/Sin nodes are
pruned automatically by the optimizer)
**`qwen3_model_generator.py`:**
- Add `include_rope=True` parameter to `create_qwen3_decoder_layer()`
- Generate full on-the-fly RoPE computation graph: `inv_freq`
initializer, `position_ids` input, MatMul/Transpose/Concat/Cos/Sin/Mul
nodes, and `rotate_half` pattern with dynamic Slice indices (including
Cast nodes from floor division)
- Apply RoPE to both Q and K paths
**`test_attention_fusion.py`:**
- Add `test_qwen3_rotary_embedding_fusion` verifying 2 RotaryEmbedding
nodes are fused along with 3 SimplifiedLayerNormalization and 1
SkipSimplifiedLayerNormalization
### Verification
- **Unit tests**: All 15 `test_attention_fusion.py` tests pass (14
existing + 1 new)
- **Real model**: Verified on Qwen3-Embedding-0.6B (28 layers): 56
RotaryEmbedding nodes fused (28 layers × 2 per layer for Q and K),
reducing total node count from 7416 → 4661 (37% reduction)
- **No regressions**: All changes are additive alternative path patterns
— existing models that use dynamic Slice indices or cache-based RoPE
never hit the new paths
- **Lint**: `lintrunner -a` clean on all modified files