transformers
356dec12 - Add Multi-Token Prediction (MTP) inference support

Commit
29 days ago
Add Multi-Token Prediction (MTP) inference support Wires MTP speculative decoding into `generate()` for DeepSeek-V3 and GLM-4 MoE checkpoints that ship MTP modules (DeepSeek-V3 at `model.layers.61`, GLM-4 MoE at `model.layers.46`/`.92` — previously hidden by `_keys_to_ignore_on_load_unexpected`). **Model side** - New `num_nextn_predict_layers: int = 0` on `DeepseekV3Config` / `Glm4MoeConfig` (propagates to downstream variants). Default keeps the existing no-op behavior. - `DeepseekV3MTPLayer` / `Glm4MoeMTPLayer` modules mirror the DeepSeek-V3 spec as implemented in vLLM: `enorm` + `hnorm` RMSNorms → concat → linear `eh_proj(2H → H)` → a full decoder block → `shared_head (norm + lm_head)`. - `DeepseekV3Model` / `Glm4MoeModel` extend `self.layers` past `num_hidden_layers` with MTP modules; the base `forward` still iterates only `self.layers[: num_hidden_layers]`. MTP is reached exclusively via a new `model.forward_mtp(input_ids, previous_hidden_state, past_key_values, position_ids, mtp_depth)` helper (lazily extends the KV cache for MTP layer indices). **Generation side** - `GenerationConfig.use_mtp: bool = False` and a new `GenerationMode.MTP_DECODING` routed from `get_generation_mode` whenever the base mode is greedy or sample. - `_mtp_decoding` in `generation/utils.py`: main forward → sample `x_{t+1}` → chain K MTP depths for draft tokens → single verify forward → reuses `_speculative_sampling` for accept/reject → `past_key_values.crop`. Batch size 1, dynamic cache; leaves `_assisted_decoding` untouched. - `ContinuousBatchingManager` refuses `use_mtp=True` for now — paged-attention slot reservation + per-request accept/reject is tracked separately and will come as a follow-up. **Tests** - `tests/generation/test_mtp.py` covers: mode dispatch, greedy token-for-token parity vs plain `_sample` for K=1/2/3 on both models, `num_nextn_predict_layers=0` rejection, layer extension, base-forward equivalence when MTP layers are added, `forward_mtp` shapes, and the `generate_batch` `NotImplementedError`. All 9 MTP tests pass locally. `make style` clean. `make fix-repo` clean apart from the pre-existing `mlinter._using_rule_specs` env mismatch in `check_modeling_rules_doc.py` / `check_modeling_structure.py` that also fails on an unmodified checkout.
Author
Parents
Loading