Add Qwen3 model type support to Python transformer optimizer (#27556)
### Description
Add `qwen3` to the Python transformer optimizer's model type registry,
enabling graph optimization for Qwen3 models (e.g.,
Qwen3-Embedding-0.6B, ranked 4th on MTEB).
### Motivation
Fixes https://github.com/microsoft/onnxruntime/issues/25083
Running `optimum-cli export onnx --optimize O3` on Qwen3 models fails
with:
```
ValueError: Unsupported model type: qwen3
```
This PR resolves that by registering the model type and fixing a fusion
gap that blocked normalization fusions.
### Changes
**Model type registration** (`optimizer.py`):
- Add `"qwen3": (Gpt2OnnxModel, "pytorch", 0)` to `MODEL_TYPES`
- Uses `Gpt2OnnxModel` (not `BertOnnxModel`) because its
  `fuse_attention()` calls `FusionRotaryAttention`, which anchors its
  pattern search on `SkipSimplifiedLayerNormalization` nodes, as needed
  for RMSNorm-based models
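The registry change amounts to a single new entry. A minimal sketch of the mapping (the placeholder class below stands in for the real `Gpt2OnnxModel` in `onnx_model_gpt2.py`; the tuple layout mirrors `optimizer.py`):

```python
# Sketch of the MODEL_TYPES registry in optimizer.py.
class Gpt2OnnxModel:
    """Placeholder: the real class wires fuse_attention() to FusionRotaryAttention."""


# Each entry maps model type -> (onnx model class, source framework,
# default onnxruntime graph optimization level).
MODEL_TYPES = {
    "qwen3": (Gpt2OnnxModel, "pytorch", 0),  # new entry added by this PR
}

model_class, framework, opt_level = MODEL_TYPES["qwen3"]
```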
**Fusion option defaults** (`fusion_options.py`):
- Disable `EmbedLayerNormalization` (decoder-only, no BERT-style
embedding)
- Set `AttentionMaskFormat.NoMask` (causal masking is implicit)
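In spirit, the new defaults look like the following sketch (a stand-in dataclass and enum, not the real `FusionOptions` / `AttentionMaskFormat` API; the enum values are illustrative):

```python
from dataclasses import dataclass
from enum import Enum


class AttentionMaskFormat(Enum):
    """Illustrative mirror of the optimizer's attention mask formats."""
    MaskIndexEnd = 0
    NoMask = 3


@dataclass
class FusionOptionsSketch:
    enable_embed_layer_norm: bool = True
    attention_mask_format: AttentionMaskFormat = AttentionMaskFormat.MaskIndexEnd

    def apply_qwen3_defaults(self) -> None:
        # Decoder-only model: there is no BERT-style embedding to fuse.
        self.enable_embed_layer_norm = False
        # Causal masking is implicit in the exported graph, so no mask input.
        self.attention_mask_format = AttentionMaskFormat.NoMask


opts = FusionOptionsSketch()
opts.apply_qwen3_defaults()
```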
**SkipLayerNormalization fusion fallback** (`fusion_skiplayernorm.py`):
- When symbolic shape inference fails (common with dynamo-exported
models), the fusion previously returned early, skipping all
`SkipLayerNormalization` / `SkipSimplifiedLayerNormalization` fusions
- Now it falls through to the safe default `skip_index=1` (the second
  `Add` input is treated as the skip connection), which is safe because
  both inputs have already been verified as non-initializer dynamic
  tensors (lines 88-90)
- This enables `SkipSimplifiedLayerNormalization` fusion on Qwen3 models
where shape inference fails
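The fallback can be sketched as follows (simplified: the real code in `fusion_skiplayernorm.py` operates on ONNX nodes and a shape-inference helper, and the shape-based branch here is an illustrative heuristic, not the actual one):

```python
def choose_skip_index(shape_0, shape_1):
    """Decide which input of the residual Add is the skip branch.

    shape_0 / shape_1 come from symbolic shape inference and are None when
    inference failed (common for dynamo-exported graphs). Both inputs are
    assumed already checked to be dynamic, non-initializer tensors.
    """
    if shape_0 is None or shape_1 is None:
        # Old behavior: return None here, aborting the fusion entirely.
        # New behavior: fall through to the safe default skip_index=1.
        return 1
    # With shapes available, a lower-rank (broadcast) input cannot be the
    # main path, so prefer it as the skip. (Illustrative heuristic only.)
    return 0 if len(shape_0) < len(shape_1) else 1
```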
**Test** (`test_attention_fusion.py`, `qwen3_model_generator.py`):
- Synthetic Qwen3 decoder layer graph with pre-attention RMSNorm, Q/K/V
projections, QK-Norm, simplified attention, output projection, residual
connection, and post-attention RMSNorm
- Verifies 3× `SimplifiedLayerNormalization` (pre-attn, Q-norm, K-norm)
+ 1× `SkipSimplifiedLayerNormalization` (residual + post-attn RMSNorm)
**Verified on real model**: Running the optimizer on an exported
Qwen3-Embedding-0.6B (2-layer) reduces nodes from 208 → 150 (28%
reduction). All 9 RMSNorm patterns fuse correctly: 5×
`SimplifiedLayerNormalization` + 4× `SkipSimplifiedLayerNormalization`.
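One way to verify such a run is to tally operator types in the optimized graph. A minimal sketch using a synthetic op list (the list merely mirrors the fused RMSNorm counts reported above; it is not the real model):

```python
from collections import Counter


def count_ops(op_types):
    """Tally op types, e.g. fed with [n.op_type for n in model.graph.node]."""
    return Counter(op_types)


# Illustrative list standing in for the optimized 2-layer model's graph.
fused_ops = (
    ["SimplifiedLayerNormalization"] * 5
    + ["SkipSimplifiedLayerNormalization"] * 4
    + ["MatMul"] * 20  # remaining ops elided
)
counts = count_ops(fused_ops)
```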
**Scope note**: Full RotaryEmbedding + MultiHeadAttention fusion for
Qwen3's dynamo-exported graphs requires additional pattern matching work
(static Slice indices, on-the-fly sin/cos computation, QK-Norm in Q/K
paths, GQA expansion). That will be addressed in a follow-up PR.
### Test Plan
- [x]
`test_attention_fusion.py::TestFusion::test_qwen3_normalization_fusion`
passes
- [x] All 14 existing tests in `test_attention_fusion.py` pass (no
regressions)
- [x] All 4 tests in `test_optimizer_huggingface_bert.py` pass (bert,
distilbert, roberta, xlm_roberta; no regressions from the
SkipLayerNorm fallback change)
- [x] `lintrunner -a` clean