transformers
f53ca05d - [Qwen3.5] Fix GDN linear attention multi-token cached forward (#45513)

Commit
19 days ago
[Qwen3.5] Fix GDN linear attention multi-token cached forward (#45513) * Fix Qwen3.5 linear attention multi-token cached forward The gated-delta-net forward only used the cached recurrent state when `seq_len == 1`. For any multi-token forward with a populated cache (e.g. chunked prefill continuation or speculative-decoding verification), it fell through to `chunk_gated_delta_rule(initial_state=None)`, silently restarting the linear layers from zero and ignoring the prefill state. This breaks the causal-LM invariant that the logits at position `i` must not depend on whether later tokens are batched into the same call — position 0 of a 16-token verify forward ended up differing from the corresponding single-token cached decode, collapsing to high-frequency context tokens and destroying speculative-decoding correctness. Add a `use_cached_chunk` path that, when `has_previous_state` is true and `seq_len > 1`: - reads the cached `conv_state` / `recurrent_state`, - prepends the conv context onto the chunk input so the causal conv sees the correct left-context, - drops the prepended context from the output, - passes the cached `recurrent_state` as `initial_state` to `chunk_gated_delta_rule`. The same fix propagates to `qwen3_5_moe` via the modular system. Add a unit test that compares the first-position output of a multi-token cached forward against the single-token cached forward on the same token and cache. Without this fix the mismatch is 100%. * Review feedback: unify cached-forward state flag, gate single-token/chunk locally Replace the two `use_precomputed_states` / `use_cached_chunk` variables with a single `use_precomputed_states = cache_params is not None and cache_params.has_previous_state(...)` that just signals "we have cached conv/recurrent state to continue from". The split between the single-token (fused per-step) and chunk-tokens (chunk kernel + cached conv context) modes is now expressed locally via `seq_len == 1` checks at the three places where it actually matters — kernel dispatch, conv-context prepend, and prepend-drop slice — as requested in review. Behavior is unchanged; this is pure restructuring for clarity. `modeling_qwen3_5.py` and `modeling_qwen3_5_moe.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`. * Propagate linear-attention multi-token cached-forward fix to qwen3_next and olmo_hybrid qwen3_next's GatedDeltaNet had the same bug as qwen3_5: for a multi-token forward after the cache was populated (chunked-prefill continuation or speculative-decoding verification), the chunk kernel was called with `initial_state=None` and the conv state was zero-padded, silently dropping the cached state. Applies the same pattern: unified `use_precomputed_states` flag (no seq_len condition), with single-token vs chunk-tokens routing gated locally on `seq_len == 1`. olmo_hybrid has the same kind of bug in its custom `OlmoHybridShortConvolution` — when the caller had cached state but fed more than one token (`use_precomputed=False` under the old `seq_len==1` gate), the conv was zero-padded instead of using the cached context. Fix the caller to drop the `seq_len==1` gate from `use_precomputed`, make the ShortConvolution branch on `seq_len==1` locally, and route recurrent vs chunk kernel dispatch in the caller likewise. olmo_hybrid's chunk kernel path already passed `initial_state=recurrent_state` correctly; no change needed there beyond the `use_precomputed` flag semantics. Add the same causal-LM invariance test to both suites (`test_linear_attention_multi_token_cached_forward_matches_single_token`). 100% element mismatch on the old code, passes after the fix. RUN_SLOW integration suites for both models pass end-to-end. `modeling_qwen3_next.py` and `modeling_olmo_hybrid.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`. * Review feedback: keep "prefill mode / multi-token decode" comment, label prepend block Restore the historical `# Multi-token forward (prefill mode)` comment on the chunk-mode else-branch in olmo_hybrid (and the equivalent qwen3_5 / qwen3_next paths) and adjust the wording so the two intents — fresh prefill vs. cached chunked-tokens decode — are visible on the same line. Tag the conv-context prepend block with a "dropped at the end of this branch" hint so a reader knows the mirror operation exists. Comment-only; behavior unchanged. Generated files re-emitted via `check_modular_conversion.py --fix_and_overwrite`.
Author
Parents
Loading