Add DeepSeek V4 (#45643)
* Add DeepSeek V4 (modular)
Adds DeepSeek V4 with hybrid CSA/HCA attention, lightning indexer,
manifold-constrained hyper-connections, shared K=V MQA with grouped
low-rank output, and per-head attention sink. Includes tokenizer/auto
mappings, fine-grained FP8 quantization support, and unit tests.
* Split V4 HCA / CSA caches and compressors into independent classes
No inheritance between HCA and CSA: each has its own cache (DynamicSlidingWindowLayer
subclass) and compressor (nn.Module subclass). HCA stays minimal (non-overlapping
windows, no indexer); CSA explicitly carries the overlap state + indexer. Shared
math factored into module-level helpers — no coff/overlap branching, no
_compress_rate_attr indirection. Also adds 'sliding_attention' to COMPRESSOR_CLASSES
with None so the three attention types are dispatched explicitly in one place.
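The explicit three-way dispatch described above can be sketched roughly as follows. This is a hypothetical stub, not the committed code: the HCA/CSA layer-type keys are assumed (only `'sliding_attention'` is named in the commit), and the compressor class bodies are placeholders.

```python
# Hypothetical sketch of the layer-type-keyed compressor dispatch. The class
# names mirror the commit; the bodies are stubs, and the "hca_attention" /
# "csa_attention" keys are assumptions.
class HCACompressor:
    """Minimal compressor: non-overlapping windows, no indexer."""

class CSACompressor:
    """Carries the overlap state and the lightning indexer."""

# 'sliding_attention' maps to None so all three attention types are
# dispatched explicitly in one place, with no implicit fall-through.
COMPRESSOR_CLASSES = {
    "hca_attention": HCACompressor,
    "csa_attention": CSACompressor,
    "sliding_attention": None,
}

def build_compressor(layer_type):
    cls = COMPRESSOR_CLASSES[layer_type]  # KeyError on unknown layer types
    return cls() if cls is not None else None
```

Keeping `None` in the mapping (rather than omitting the key) is what makes unknown layer types fail loudly instead of silently getting no compressor.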
* Fix tests_generate / tests_tensor_parallel CI failures
Generation tests were assuming V4 supports advanced decoding modes (assisted
generation, prompt lookup, contrastive search, static-cache compile) that the
compressor's running-window cache state can't service — its buffer / pool /
overlap fields aren't rewindable across drafts and aren't compatible with
:class:`StaticCache`. Set the right opt-out flags so generate raises a clear
error early and the corresponding tests skip cleanly:
* ``_is_stateful = True`` — gates assisted / prompt-lookup paths.
* ``_can_compile_fullgraph = False`` — gates the static-cache test (would
otherwise hand the compressor a :class:`StaticSlidingWindowLayer` with no
``update_compressor`` method).
* ``_supports_flex_attn = False`` — V4 only validates eager attention; the
compressor / indexer paths weren't checked under flex / SDPA / flash kernels.
Conversion mapping cleanup so save / load round-trips survive:
* Standardize on V3's ``apply_rotary_pos_emb_interleave`` for the partial-RoPE
rotation, with a thin V4-side wrapper that permutes the rope channels back
from the halves layout V3 leaves them in to the interleaved layout V4 was
trained with — required because V4 is shared-KV (V == K rotated), so V's
channel layout flows through ``wo_a`` / ``wo_b``.
* Restructure ``conversion_mapping.deepseek_v4`` into two passes: structural
prefix renames first (``layers.X.attn.`` → ``model.layers.X.self_attn.``),
then specific in-prefix renames on the already-prefixed HF-form keys
(``...self_attn.compressor.norm.`` → ``...self_attn.compressor.kv_norm.``).
A single-pass ordering loses information in either the forward or reverse
direction (overlapping general / specific patterns conflict).
* Move the FP8 ``.scale`` → ``.weight_scale_inv`` rename out of the V4 static
conversion list and into ``FineGrainedFP8HfQuantizer.update_weight_conversions``
so the rule is only registered when FP8 dequant is actually active. Lets
``test_reverse_loading_mapping`` skip an unrelated FP8 rule on plain saves.
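The two-pass ordering above can be illustrated with a simplified sketch; the pattern strings here are illustrative stand-ins, not the actual entries in `conversion_mapping.deepseek_v4`.

```python
import re

# Simplified sketch of the two-pass rename: structural prefix renames first,
# then specific in-prefix renames on the already-prefixed HF-form keys.
STRUCTURAL = [  # Pass 1: upstream layout -> HF layout
    (r"^layers\.(\d+)\.attn\.", r"model.layers.\1.self_attn."),
]
SPECIFIC = [    # Pass 2: renames expressed against the HF-form keys
    (r"\.self_attn\.compressor\.norm\.", r".self_attn.compressor.kv_norm."),
]

def convert_key(key):
    for pat, repl in STRUCTURAL:
        key = re.sub(pat, repl, key)
    for pat, repl in SPECIFIC:
        key = re.sub(pat, repl, key)
    return key
```

Because the Pass 2 source patterns are written in HF form, they only match after Pass 1 has run, which is exactly why a single-pass (or naively reversed) application of the same list cannot round-trip.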
Test fixes:
* Skip ``test_reverse_loading_mapping`` with a docstring spelling out why the
two-pass mapping can't satisfy that test's invariant (its Pass 2 source
patterns are HF-form by design; ``test_save_load`` exercises the actual
round-trip).
* Skip ``test_left_padding_compatibility`` — V4's compressor pre-pools
``compress_rate``-token windows before the attention mask is applied, so
left padding shifts window boundaries and folds pad tokens into pooled
KV entries (same fundamental limit as RecurrentGemma).
* Add ``model.to(torch_device)`` in the ``test_hidden_states_output`` override
  so CUDA inputs don't hit a CPU model.
* ``test_tiny_generate_runs`` now passes ``eos_token_id=-1`` so a freshly
initialised random model doesn't EOS-stop before max_new_tokens, making the
shape assertion deterministic.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
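The left-padding limitation skipped above can be demonstrated with a toy pooling function. Mean-pooling semantics and `compress_rate=2` are simplifying assumptions; the point is only that window boundaries are position-based and shift under left padding.

```python
# Toy illustration: the compressor pools fixed-size windows of compress_rate
# tokens *before* the attention mask is applied, so left padding shifts which
# tokens share a pooled KV entry and folds pad tokens into pooled windows.
def pooled_windows(tokens, compress_rate=2):
    return [tuple(tokens[i:i + compress_rate])
            for i in range(0, len(tokens), compress_rate)]

tokens = ["a", "b", "c", "d"]
padded = ["<pad>"] + tokens

no_pad = pooled_windows(tokens)    # [('a', 'b'), ('c', 'd')]
with_pad = pooled_windows(padded)  # [('<pad>', 'a'), ('b', 'c'), ('d',)]
```

No masking applied after the fact can recover `('a', 'b')` from the padded run, since `'a'` has already been pooled together with the pad token, which is the same fundamental limit RecurrentGemma hits.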
* Address PR review feedback batch (comments 2-24)
- apply_rotary_pos_emb takes one tensor + handles trailing-rope slicing internally;
rotate_half-style ernie pattern with repeat_interleave; rotary forward emits
half-sized cos/sin (no end-to-end duplication).
- Inherit DeepseekV4RotaryEmbedding from LagunaRotaryEmbedding (partial-rotary
compute_default_rope_parameters).
- Config:
* compress_rates dict keyed by layer type (BC kwargs for compress_rate_csa/hca).
* mlp_layer_types list (BC kwargs for num_hash_layers); MLPBlock dispatches via it.
* qk_rope_head_dim derived from partial_rotary_factor (BC kwarg accepted).
* Drop V3 inheritance + V3-only fields (kv_lora_rank, qk_nope_head_dim, v_head_dim,
n_group, topk_group, first_k_dense_replace, rope_interleave).
- Rename attention/compressor/indexer leaf weights to *_proj convention; add
conversion_mapping rules to load upstream wq_*/wkv/wgate/wo_* names.
- DeepseekV4MLP no longer inherits Qwen2MoeMLP — uses moe_intermediate_size.
- GroupedLinear forward simplified to MHA-style transpose pattern.
- Indexer / compressor: pool window views use -1 last dim (TP-friendly), softmax
in fp32, rope_layer_type as class attr.
- Drop dead self.compress_rate / self.qk_nope_head_dim assignments.
* Address PR review feedback batch (comments 25-42)
- DeepseekV4UnweightedRMSNorm: extracted weight-less RMSNorm class, used by
attention's per-head Q rescale + both HC modules' input rescale.
- HyperConnection.forward returns (post, comb, collapsed) — moves the stream
collapse into the mHC module instead of the DecoderLayer.
- Document the 3 in mHC scale param (pre / post / comb).
- DecoderLayer: input_ids in explicit signature (was kwargs.get).
- Comment defending the compressor mask pad against FA / SDPA backends.
- DeepseekV4Router: unified TopK + Hash routers into one class with a
select_indices hook (top-k + e_score_correction_bias vs tid2eid lookup).
- Rename buffer ``bias`` → ``e_score_correction_bias`` (cross-model standard);
add gate.bias → e_score_correction_bias rule in conversion_mapping.
- DeepseekV4Experts: use config.num_local_experts (routes through attribute_map)
so FP8 / TP integrations stay robust.
- Drop unused self.rotary_emb_compress on the model.
- Simplify DeepseekV4ForCausalLM to a bare `pass` inheriting MixtralForCausalLM.
* Fix Fp8Dequantize.reverse_op to actually re-quantize on save
reverse_op was _IdentityOp, so saving a model that had been loaded with
dequantize=True dropped the FP8 layout — saved checkpoints lost their
weight_scale_inv keys and round-trip through save_pretrained was lossy. Pair the
two ops symmetrically: Fp8Dequantize.reverse_op -> Fp8Quantize and
Fp8Quantize.reverse_op -> Fp8Dequantize.
Fp8Quantize.convert refactored to handle the per-expert save chain
(SplitModulelist emits one key per expert -> Fp8Quantize quantizes each), and to
pass non-tileable tensors through unchanged (1D norms / biases / odd 2D shapes
that were never quantized on the load side).
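The symmetric pairing can be sketched as follows. The class and attribute names mirror the commit, but the `convert` bodies here are stubs; only the `reverse_op` wiring, which is the actual fix, is shown.

```python
# Minimal sketch of the symmetric op pairing. Before the fix, the dequantize
# op's reverse was an identity, so save_pretrained dropped the FP8 layout.
class Fp8Quantize:
    reverse_op = None  # patched below once Fp8Dequantize exists

    def convert(self, tensor):
        return tensor  # stub: would tile-quantize and emit weight_scale_inv

class Fp8Dequantize:
    reverse_op = Fp8Quantize

    def convert(self, tensor):
        return tensor  # stub: would multiply the scales back in

# Pair the two ops symmetrically so saving re-quantizes exactly what loading
# dequantized, keeping the save_pretrained round-trip lossless.
Fp8Quantize.reverse_op = Fp8Dequantize
```

With the pair closed under reversal, walking a conversion chain backwards on save mechanically regenerates the `weight_scale_inv` keys that loading consumed.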
* Address Arthur's review batch + revisit two of vasqu's comments
- Drop the local rotate_half def, import from glm.modeling_glm (identical body).
- Iterate set(self.layer_types) in DeepseekV4RotaryEmbedding.__init__ for
consistency with the gemma3 idiom.
- DeepseekV4MLP inherits LlamaMLP (was a hand-written nn.Module). Config
attribute_map routes intermediate_size -> moe_intermediate_size and adds
mlp_bias=False, so LlamaMLP's __init__ builds the right shared-expert linears
without an override.
- DeepseekV4Experts inherits MixtralExperts (was GptOssExperts with an
__init__ + _apply_gate override that duplicated everything). MixtralExperts'
layout matches V4-Flash's; the only V4-specific bit is the swiglu_limit clamp
on gate / up before SiLU, kept inline in the overridden forward.
- Split the unified DeepseekV4Router back into DeepseekV4TopKRouter and
DeepseekV4HashRouter (Arthur preferred two explicit classes over a
conditional select_indices hook).
- Drop **_ from DeepseekV4SparseMoeBlock.forward — the layer's caller
(DeepseekV4DecoderLayer) already filters kwargs.
- DeepseekV4Model now inherits LlamaModel. super().__init__ sets up
embed_tokens / norm / rotary_emb / gradient_checkpointing; we override the
layer list, swap rotary_emb for the multi-layer-type V4 one, add hc_head, and
keep the V4-specific forward.
* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Move DeepseekV4Config out of modular + simplify __post_init__
Configuration is now hand-edited in configuration_deepseek_v4.py — the modular
file no longer defines it, drops it from __all__, and imports it instead. The
converter no longer regenerates the config file (with no class bearing the
Config suffix, there is nothing to emit there).
__post_init__ is collapsed into five small _resolve_* methods + a single
_apply_legacy_kwargs helper that strips the legacy V3-flavoured kwargs
(compress_rate_csa/hca, num_hash_layers, qk_rope_head_dim, compress_ratios)
into typed instance fields, so __post_init__ itself reads as a sequence of
named steps.
Also expand docs/source/en/model_doc/deepseek_v4.md with an Architecture section
(hybrid attention / mHC / MoE schedule / cache layers) cross-referenced to the
paper sections.
Type-check fix: gate the WeightConverter.operations access in
quantizer_finegrained_fp8.py with isinstance, so WeightRenaming entries pass
through untouched.
* Fix V4 TP failures: dynamic num_key_value_groups + FP8-safe GroupedLinear
V4 is shared-KV MQA (num_kv_heads = 1). With TP, q_b_proj is colwise-sharded so
the local q has num_heads / tp_size heads while kv stays replicated at one head.
The eager / sdpa / flash backends all read module.num_key_value_groups to repeat
kv up to q's head count — a fixed global value of num_attention_heads gives the
wrong (over-)expansion factor on every rank but the first. Refresh
num_key_value_groups from q.shape[1] in DeepseekV4Attention.forward, after the
local q has been built, so repeat_kv(key, num_key_value_groups) lifts the single
kv head to exactly the rank-local query head count.
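The shape arithmetic behind this fix can be shown with toy numbers (illustrative, not taken from any real V4 config):

```python
# Toy shape arithmetic for the dynamic num_key_value_groups refresh: with
# shared-KV MQA (1 kv head) and a colwise-sharded q_b_proj, the repeat factor
# must come from the *local* q shape, not the global head count.
num_attention_heads, num_kv_heads, tp_size = 8, 1, 4

local_q_heads = num_attention_heads // tp_size      # q heads on this rank: 2
wrong_groups = num_attention_heads // num_kv_heads  # fixed global value: 8
right_groups = local_q_heads // num_kv_heads        # refreshed per rank: 2

# repeat_kv(key, right_groups) lifts the single kv head to exactly the
# rank-local query head count; wrong_groups would over-expand it 4x.
```

Reading the factor off `q.shape[1]` after the local q is built makes the same forward correct on every rank, with or without TP.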
DeepseekV4GroupedLinear was using a single bmm for the per-group projection.
torchao's Float8Tensor (used by tests_tensor_parallel_ci's
test_tp_generation_quantized) only fast-paths F.linear; bmm hits an mslk kernel
assertion (`bmm is not supported when mslk is not installed`). Replace the bmm
with a small per-group F.linear loop — slower for tiny configs, but cuts the
torchao dependency and the quantized-TP path now works without mslk.
* Revert GroupedLinear F.linear loop, keep bmm
The bmm was changed to F.linear because torchao's Float8Tensor doesn't fast-path
bmm without the mslk kernel. Reverting since a custom V4 FP8 path will land
later — we don't want to slow the unquantized GroupedLinear forward (~8x more
ops with n_groups=8) just to avoid one CI failure on the quantized-TP test.
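The equivalence of the two GroupedLinear formulations discussed across these two commits can be checked numerically. This NumPy sketch uses tiny illustrative shapes (not V4's real dimensions): one batched matmul over groups versus a per-group `F.linear`-style loop.

```python
import numpy as np

# Equivalence of the bmm form and the per-group linear loop for a grouped
# projection. Shapes are illustrative: (n_groups, tokens, group_dim) inputs
# against (n_groups, out_dim, group_dim) weights.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 3))   # grouped activations
w = rng.standard_normal((2, 4, 3))   # per-group weight matrices

# bmm form: one batched matmul, (g, t, d) @ (g, d, o) -> (g, t, o)
bmm_out = x @ w.transpose(0, 2, 1)

# loop form: F.linear(x_g, w_g) == x_g @ w_g.T, applied group by group
loop_out = np.stack([x[g] @ w[g].T for g in range(2)])
```

Numerically identical, so the choice is purely about kernel dispatch: the loop plays nicer with `F.linear`-only fast paths like torchao's Float8Tensor, while the single bmm keeps the unquantized forward cheap.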
* Fix V4 GroupedLinear comment with real V4-Flash / V4-Pro config values
The previous comment had every number wrong: V4-Flash uses o_groups=8 (not 16),
groups are 4096-dim (not 2048), and V4-Pro mixes to 7168-dim (not 4096).
Cross-checked against the deepseek-ai/DeepSeek-V4-Flash-Base and
deepseek-ai/DeepSeek-V4-Pro-Base configs on the Hub.
* up
Co-authored-by: Copilot <copilot@github.com>
* up
Co-authored-by: Copilot <copilot@github.com>
* small cleanup
* repo fixes
* nits
Co-authored-by: Copilot <copilot@github.com>
* more nits
* nits, small thing left to do
* update
* update DeepseekV4HCACache
* more update
* nits
* update
* update
* nits
* update
* fixes
* nits
* Fix CI: `\N` (N>1) backrefs in conversion mapping + drop irrelevant drift
Extend `WeightTransform.rename_source_key` from `\1`-only to `\1..\9`
substitution (`re.sub` over the existing match object, indexed off the
matched named group). The V4 conversion mapping uses `\2` for the inner
module path (`compressor` / `compressor.indexer`); without this fix the
literal `\2` leaks into the generated key names, producing UNEXPECTED
entries like `model.layers.{2...42}.self_attn.\2.q_b_proj.weight` and
MISSING entries on the matching real keys, and downstream the
freshly-randomised compressor weights make generation NaN out.
Existing single-`\1` callers are unaffected (`re.sub(r"\\\d", ...)` is
strictly a superset of the prior `replace(r"\1", ...)` path).
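The failure mode can be reproduced with a small sketch. `re.Match.expand` is used here as a stand-in for the extended substitution (the committed code runs `re.sub` over the match object), and the pattern/template strings are illustrative, not the real mapping entries:

```python
import re

# Sketch of the \N>1 backref fix. \2 captures the inner module path
# ("compressor" / "compressor.indexer"); the illustrative source name wq_a
# maps to the *_proj convention.
pattern = re.compile(r"layers\.(\d+)\.(compressor(?:\.indexer)?)\.wq_a\.weight")
template = r"model.layers.\1.self_attn.\2.q_a_proj.weight"

def rename(key):
    m = pattern.match(key)
    if m is None:
        return key
    # A naive template.replace(r"\1", m.group(1)) would leave the literal \2
    # in the output key; expand() substitutes every \1..\9 group.
    return m.expand(template)
```

With `\1`-only substitution, the literal `\2` leaks into the generated key names, producing exactly the UNEXPECTED/MISSING pairs described above.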
Also revert non-functional drift in maskformer, zamba2 and import_utils
to match main: line-wrap-only modeling_maskformer.py, docstring
reordering in configuration_zamba2.py, line-wrap in import_utils.py.
None of these changed behaviour; keeping them in the diff would re-run
`check_modular_conversion` on maskformer (whose own modular-regen on
main produces a doubled `MaskFormerMaskFormerDetr*` prefix that was
hand-fixed in the committed file).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Drop TP plan from V4 attention path; keep EP-only on experts
V4 attention is shared-KV MQA (`num_key_value_heads = 1`) plus a CSA / HCA
compressor branch fed by the same `kv_proj`. Both paths rely on
`repeat_kv` to broadcast the single KV head to the full attention head
count. Colwise-sharding `q_b_proj` would split the queries per rank but
leave KV replicated, so `repeat_kv(key, num_key_value_groups)` would
expand to the *global* head count while the rank-local query only has
`num_heads / tp_size` heads — the per-rank attention matmul fails with
the shape mismatch we just hit on tests_tensor_parallel_ci
(`tensor a (2) must match tensor b (4) at non-singleton dimension 1`).
The compressor branch can't be made to follow either, since its keys
come from a separate stateful path.
Easier and saner: leave the attention block fully replicated across TP
ranks and parallelise only what actually carries the parameter weight in
this model — the routed MoE experts (`moe_tp_experts` + the gate/down
shards) and the shared experts. That's where 99% of the parameters live
anyway.
Removes the four attention/MLP-attention TP rules and keeps the MoE +
shared_experts ones intact.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Address vasqu review batch + skip quantized-TP test
modular_deepseek_v4:
- Restore docstrings on apply_rotary_pos_emb + DeepseekV4UnweightedRMSNorm.
- DeepseekV4RotaryEmbedding.forward: comment defending why we don't duplicate
freqs end-to-end (interleaved vs half-split RoPE).
- Indexer: rename self.n_heads -> self.num_heads (matches Attention).
- DeepseekV4Attention: rename q_norm -> q_a_norm, q_head_norm -> q_b_norm to
match the q_a_proj / q_b_proj symmetry.
- DeepseekV4Attention: hoist `dtype = hidden_states.dtype` to the top of
DecoderLayer.forward (was duplicated across the two HC sites).
- DeepseekV4Attention: rewrite the conjugate-rope output comment (the transpose
pair is a layout fix-up; the -sin is the actual fix).
- DeepseekV4Attention: collapse `.reshape(...).view(...)` chain into one reshape
for the grouped output projection input.
- Drop _ prefix on q_residual / position_ids in HCA compressor signature.
- Compressor branch dispatch: `COMPRESSOR_CLASSES[layer_type](config) if
layer_type != "sliding_attention" else None` (one line, layer-type keyed).
- DeepseekV4Experts: extract _apply_gate hook (gpt-oss style) so grouped_mm /
batched_mm backends apply the swiglu_limit clamp + SiLU instead of bypassing.
- Local rename in store_compression_weights: bk/bg -> buffered_kv/buffered_gate.
- Local rename in DeepseekV4GroupedLinear: d_in -> hidden_dim.
- Add e_score_correction_bias to _keep_in_fp32_modules_strict.
- Flip _supports_flash_attn / _supports_flex_attn to True (gpt-oss-style).
cache_utils:
- StaticCache no longer hardcodes V4 layer-type names. Treat any layer type
whose registered DynamicCache class is a DynamicSlidingWindowLayer subclass as
sliding for static-cache purposes — keeps model-specific names in their model.
conversion_mapping:
- Add upstream `attn.q_norm` -> `self_attn.q_a_norm` rule for the q_norm rename.
finegrained_fp8: revert local-test leftovers
- Drop the JIT smoke-test in _load_deepgemm_kernel.
- Restore _first_attr(...) helper calls (were inlined to getattr locally).
- Uncomment "deepgemm" entry in FP8ExpertsInterface._global_mapping.
tests/deepseek_v4: skip test_tp_generation_quantized — torchao Float8Tensor's
bmm path needs the optional mslk kernel; a custom V4 FP8 path lands later.
* Drop model. prefix from V4 conversion mapping + tester / parity cleanup
conversion_mapping (V4):
- Drop ``model.`` prefix from every source / target pattern. The base-model
prefix is auto-added/stripped by ``convert_and_load_state_dict_in_model`` based
on ``base_model_prefix = "model"``, so the mapping only needs to operate in the
bare base-model namespace.
cache_utils:
- StaticCache treats any layer type whose registered DynamicCache class is a
``DynamicSlidingWindowLayer`` subclass as sliding for static-cache purposes,
but keeps ``chunked_attention`` on its own branch (it uses
``config.attention_chunk_size`` instead of ``config.sliding_window``). This
fixes the regression in tests/utils/test_cache_utils.py
(test_hybrid_chunked_cache / _extra_cases) and tests/models/olmo/...
(test_generate_with_static_cache) that came from routing chunked through the
sliding branch.
- Drop the inline comment on the dispatch loop and the dispatch-via-registry
comment in DynamicCache (vasqu's review request).
tests/deepseek_v4:
- Remove DeepseekV4ParityTest and _tiny_config (functional sanity checks
duplicated by the standard CausalLMModelTest suite).
- ModelTester: replace the ``kwargs.setdefault(...)`` mess with explicit
``self.X = ...`` overrides after ``super().__init__()`` — same effect, reads
as a normal subclass override.
- Drop unused imports (DeepseekV4Config, DynamicCache, DeepseekV4HCACompressor,
DeepseekV4ForCausalLM).
docs/deepseek_v4.md / model_doc/lasr.md / auto_mappings.py: ``make fix-repo``
side-effects (release date, auto mapping ordering).
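The subclass-based sliding detection, with chunked attention kept on its own branch, can be sketched as follows. Class names mirror the commit; the registry shape and the `"csa_attention"` layer-type key are assumptions.

```python
# Sketch of the cache_utils dispatch: instead of hardcoding V4 layer-type
# names in StaticCache, treat any layer type whose registered dynamic class
# subclasses DynamicSlidingWindowLayer as sliding, except chunked attention,
# which sizes itself from config.attention_chunk_size instead.
class DynamicLayer: ...
class DynamicSlidingWindowLayer(DynamicLayer): ...
class DeepseekV4CSACache(DynamicSlidingWindowLayer): ...  # model-specific name

DYNAMIC_CLASSES = {
    "full_attention": DynamicLayer,
    "csa_attention": DeepseekV4CSACache,  # assumed layer-type key
    "chunked_attention": DynamicLayer,
}

def is_sliding(layer_type):
    if layer_type == "chunked_attention":  # own branch, not sliding-sized
        return False
    return issubclass(DYNAMIC_CLASSES[layer_type], DynamicSlidingWindowLayer)
```

This keeps model-specific names in their model file while still letting the shared static cache size sliding layers correctly.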
* Silence 'Unrecognized rope_parameters keys' warning
V4's rope_parameters dict is keyed by *rope-type* labels ("main" / "compress"),
not by config.layer_types. The base PreTrainedConfig.validate_rope checks
`keys ⊆ layer_types` and falls back to wrapping the whole dict as one set of
params when the subset check fails — which then warns 'Unrecognized keys for
rope_type=default: {main, compress}'. Override validate_rope on V4Config to
iterate the rope-type-keyed sub-dicts directly.
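The override amounts to walking the rope-type-keyed sub-dicts directly. A minimal sketch, with illustrative parameter contents:

```python
# Sketch of the validate_rope override: V4's rope_parameters is keyed by
# rope-type labels ("main" / "compress"), not by config.layer_types, so
# validation iterates the sub-dicts instead of the base-class subset check.
rope_parameters = {
    "main": {"rope_type": "yarn", "factor": 16},
    "compress": {"rope_type": "default"},
}

def validate_rope(rope_parameters):
    validated = []
    for label, params in rope_parameters.items():
        assert "rope_type" in params, f"missing rope_type for {label!r}"
        validated.append(label)
    return validated
```

Since the labels are validated on their own terms, the base class never sees a failed `keys ⊆ layer_types` check and the spurious fallback warning disappears.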
* Disable flex_attention on V4; guard tensor-only mask pad
V4 attention concatenates compressor entries onto the KV axis *inside*
the attention block, after the model-level attention mask was built.
Flex's `BlockMask` is sized for the pre-concat KV length and has no
runtime resize, so `flex_attention(query, key, ..., block_mask=...)`
errors with `block_mask.kv_len=N but got kv_len=M`. Rebuilding the
BlockMask per-block would require teaching the compressor's variable
output count to a `mask_mod`, and the compressor already owns its own
causality bookkeeping — not worth it.
Set `_supports_flex_attn = False` (matches the existing `_supports_sdpa
= False` reasoning: torch's SDPA kernel doesn't carry V4's per-head
sink either).
Also guard the existing right-pad so it only fires when
`attention_mask` is a `torch.Tensor`. Even with flex disabled at the
class level, the pad on a non-tensor mask is structurally wrong — keep
the guard so the explicit-mask path can't crash on a `BlockMask`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Switch V4 to EP-only plan (gpt-oss style)
V4 attention can't be TP'd safely (shared-KV MQA + compressor's
single-head broadcast would mismatch rank-local query head counts under
colwise q_b_proj). The shared MLP is also tiny and not worth TP-ing.
Match gpt-oss instead: route on the gate, run the routed experts as a
grouped-GEMM kernel sharded along the expert axis, and wrap the experts
module with `moe_tp_experts` so its output is all-reduced across ranks.
Renames `base_model_tp_plan` → `base_model_ep_plan`. The `test_tp_*`
mixin tests skip on V4 same as gpt-oss; the `test_ep_*` tests run and
pass on CPU multiprocessing.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Fix V4 yarn rope validation under nested rope_parameters
The yarn / longrope / llama3 validators in
RotaryEmbeddingConfigMixin read self.rope_parameters[<key>] directly
(e.g. original_max_position_embeddings). With V4's rope-type-keyed
nesting, the top-level dict only has main / compress, so those reads
fail with KeyError when loading any V4 checkpoint that uses yarn
(e.g. DeepSeek-V4-Flash, factor=16, original_max_position_embeddings=
65536).
In V4's validate_rope override, temporarily point self.rope_parameters
at the rope-type-specific sub-dict for the duration of the validation
call, then restore it. Strictly local change — no behaviour shift for
non-yarn / non-scaled rope.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
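The swap-and-restore can be sketched with a stand-in validator. Class and attribute names follow the commit; `_base_validate` is a hypothetical proxy for the mixin's direct `self.rope_parameters[<key>]` reads:

```python
# Sketch of the temporary un-nesting: point self.rope_parameters at the
# rope-type-specific sub-dict while the base validator runs, then restore
# the nested dict in a finally block so failures can't leave it un-nested.
class V4Config:
    def __init__(self):
        self.rope_parameters = {
            "main": {"rope_type": "yarn",
                     "original_max_position_embeddings": 65536},
            "compress": {"rope_type": "default"},
        }

    def _base_validate(self):
        # stand-in for the mixin: reads self.rope_parameters[<key>] directly
        return self.rope_parameters.get("rope_type", "default")

    def validate_rope(self):
        nested = self.rope_parameters
        results = {}
        for label, sub in nested.items():
            self.rope_parameters = sub        # temporarily un-nest
            try:
                results[label] = self._base_validate()
            finally:
                self.rope_parameters = nested  # always restore
        return results
```

The change is strictly local: each base-validator call sees a flat dict, exactly as it would on a non-nested model, and the nested layout survives untouched.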
* Trim V4 config-attribute allowlist from 24 → 6
The checker (`utils/check_config_attributes.py`) statically scans
`config.<attr>` reads in modeling files. Of the 24 V4 entries we'd been
allowlisting as "unused", 18 are in fact picked up by the scanner —
they were leftovers from an earlier pass before V4's modeling layout
settled. Running the checker with the V4 list emptied shows only six attrs
fire as genuinely unused; keep just those, with one comment per entry
explaining what it's there for (every one of them is BC / config-compat
surface that we accept in `__init__` but the modeling code never reads):
- `attention_bias` — no biases on V4 linears.
- `n_shared_experts` — always exactly one shared MLP, count
is never iterated over.
- `norm_topk_prob` — V3 knob; V4's router always normalises.
- `num_key_value_heads` — V4 is shared-KV MQA (always 1).
- `num_nextn_predict_layers` — upstream MTP count; MTP head isn't
instantiated by transformers' V4.
- `router_jitter_noise` — Mixtral inheritance; V4 doesn't jitter.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Inline V4 config legacy-kwarg + resolve_* helpers into __post_init__
Six tiny helpers (`_apply_legacy_kwargs` + five `_resolve_*`) each did
one targeted fold and were called once, in order, from `__post_init__`.
The indirection didn't carry its own weight — the resolution sequence
reads more clearly as one function with comments on each step than as a
chain of method dispatches that you have to scroll between.
Behaviour-preserving rewrite:
- kwargs.pop the legacy V4 names into local variables instead of
stashing them on `self._legacy_*`. They're consumed inline below.
- keep the resolution order identical to before
(`PreTrainedConfig.__post_init__` first, then compress_rates →
layer_types → mlp_layer_types → partial_rotary_factor →
rope_parameters).
- `rope_parameters` re-nesting still has the explicit `{main, compress}`
re-assignment in the already-nested branch (drops leftover top-level
keys, keeps `to_dict` round-trip stable for `test_config`).
108 → 39 (config_to_json + V4 unit suite + EP/TP marker tests all
green; real V4-Flash load round-trips correctly).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Disable FlashAttention on V4: head_dim=512 exceeds the 256 cap
Verified locally on V4-Flash:
* `flash_attention_2` → RuntimeError:
FlashAttention forward only supports head dimension at most 256
* `kernels-community/vllm-flash-attn3` (FA3) → same 256 cap.
* `flash_attention_4` / `kernels-community/flash-attn4` → same kernel
family, same 256 cap.
V4-Flash and V4-Pro both ship `head_dim = 512`, so none of the FA*
backends can dispatch. Setting `_supports_flash_attn = False` makes
`set_attn_implementation` reject the request up front instead of
loading the model and exploding inside the kernel call.
Eager remains the supported attention path. SDPA and FlexAttention were
already off for the reasons in the existing comment (per-head sink term
not in SDPA; compressor concat invalidates BlockMask's kv_len).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* update
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Copilot <copilot@github.com>