[Qwen3.5] Fix GDN linear attention multi-token cached forward (#45513)
* Fix Qwen3.5 linear attention multi-token cached forward
The gated-delta-net forward used the cached recurrent state only when
`seq_len == 1`. Any multi-token forward with a populated cache (e.g. a
chunked-prefill continuation or speculative-decoding verification) fell
through to `chunk_gated_delta_rule(initial_state=None)`, silently
restarting the linear-attention state from zero and ignoring the prefill
state.
This breaks the causal-LM invariant that the logits at position `i` must
not depend on whether later tokens are batched into the same call —
position 0 of a 16-token verify forward ended up differing from the
corresponding single-token cached decode, with generations collapsing to
high-frequency context tokens and destroying speculative-decoding
correctness.
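In toy form (a pure-numpy stand-in for the fused chunk kernel; shapes and the recurrence are illustrative, only the `initial_state` semantics mirror the real `chunk_gated_delta_rule`):

```python
import numpy as np

def chunk_gated_delta_rule(q, k, v, g, initial_state=None):
    # Toy stand-in for the chunk kernel: S_t = g_t * S_{t-1} + k_t v_t^T,
    # o_t = q_t @ S_t.  The real kernel is fused and chunked; only the
    # initial_state handling matters for this bug.
    d = k.shape[1]
    S = np.zeros((d, d)) if initial_state is None else initial_state
    out = np.empty_like(q)
    for t in range(q.shape[0]):
        S = g[t] * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S
    return out, S

rng = np.random.default_rng(0)
d = 4
q, k, v = (rng.standard_normal((8, d)) for _ in range(3))
g = rng.uniform(0.5, 1.0, 8)

# Prefill on the first 4 tokens populates the cache.
_, cached_state = chunk_gated_delta_rule(q[:4], k[:4], v[:4], g[:4])

# Reference: single-token cached decode of token 4.
ref, _ = chunk_gated_delta_rule(q[4:5], k[4:5], v[4:5], g[4:5],
                                initial_state=cached_state)

# Old code: multi-token cached forward silently drops the state.
buggy, _ = chunk_gated_delta_rule(q[4:], k[4:], v[4:], g[4:])

# Fixed: cached state passed through as initial_state.
fixed, _ = chunk_gated_delta_rule(q[4:], k[4:], v[4:], g[4:],
                                  initial_state=cached_state)

print(np.allclose(fixed[0], ref[0]))   # position 0 matches single-token decode
print(np.allclose(buggy[0], ref[0]))   # violated under the old code
```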
Add a `use_cached_chunk` path that, when `has_previous_state` is true
and `seq_len > 1`:
- reads the cached `conv_state` / `recurrent_state`,
- prepends the conv context onto the chunk input so the causal conv sees
the correct left-context,
- drops the prepended context from the output,
- passes the cached `recurrent_state` as `initial_state` to
`chunk_gated_delta_rule`.
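The conv-context prepend/drop bookkeeping looks roughly like this (toy depthwise causal conv in numpy; `conv_state` holding the last `K-1` raw inputs is a simplification of the real cache layout):

```python
import numpy as np

K = 4  # causal conv kernel size (toy value)

def causal_conv(x, w):
    # Depthwise causal 1-D conv with zero left-padding (fresh-prefill form).
    pad = np.zeros((K - 1, x.shape[1]))
    xp = np.concatenate([pad, x], axis=0)
    return np.stack([(xp[t:t + K] * w).sum(axis=0) for t in range(x.shape[0])])

rng = np.random.default_rng(1)
dim = 3
w = rng.standard_normal((K, dim))
x = rng.standard_normal((10, dim))

full = causal_conv(x, w)           # reference: everything in one fresh prefill

conv_state = x[6 - (K - 1):6]      # cache after a 6-token prefill: last K-1 inputs
chunk_in = np.concatenate([conv_state, x[6:]], axis=0)  # prepend conv context
chunk_out = causal_conv(chunk_in, w)[K - 1:]            # drop prepended positions

print(np.allclose(chunk_out, full[6:]))  # cached chunk matches the full pass
```

Without the prepend, the first `K-1` chunk positions would see zero padding instead of the prefill tokens; without the drop, the output would be `K-1` positions too long.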
The same fix propagates to `qwen3_5_moe` via the modular system.
Add a unit test that compares the first-position output of a
multi-token cached forward against the single-token cached forward on
the same token and cache. Without this fix the mismatch is 100%.
* Review feedback: unify cached-forward state flag, gate single-token/chunk locally
Replace the two `use_precomputed_states` / `use_cached_chunk` variables with a single
`use_precomputed_states = cache_params is not None and cache_params.has_previous_state(...)`
that just signals "we have cached conv/recurrent state to continue from". The split between
the single-token (fused per-step) and chunk-tokens (chunk kernel + cached conv context) modes
is now expressed locally via `seq_len == 1` checks at the three places where it actually
matters — kernel dispatch, conv-context prepend, and prepend-drop slice — as requested in
review.
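Schematically, the restructured dispatch reads as follows (branch labels are illustrative, not the real kernel names):

```python
def linear_attention_dispatch(seq_len, has_previous_state):
    # The single flag only says "cached conv/recurrent state exists";
    # single-token vs chunk routing is decided locally where it matters.
    use_precomputed_states = has_previous_state  # no seq_len condition

    if use_precomputed_states and seq_len == 1:
        return "fused_recurrent_step"   # per-step kernel on cached state
    if use_precomputed_states:
        return "chunk_kernel_cached"    # cached conv ctx + initial_state
    return "chunk_kernel_fresh"         # fresh prefill, zero state

print(linear_attention_dispatch(1, True))    # fused_recurrent_step
print(linear_attention_dispatch(16, True))   # chunk_kernel_cached
print(linear_attention_dispatch(16, False))  # chunk_kernel_fresh
```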
Behavior is unchanged; this is pure restructuring for clarity. `modeling_qwen3_5.py` and
`modeling_qwen3_5_moe.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`.
* Propagate linear-attention multi-token cached-forward fix to qwen3_next and olmo_hybrid
qwen3_next's GatedDeltaNet had the same bug as qwen3_5: for a multi-token forward after the
cache was populated (chunked-prefill continuation or speculative-decoding verification), the
chunk kernel was called with `initial_state=None` and the conv state was zero-padded, silently
dropping the cached state. Apply the same pattern: a unified `use_precomputed_states` flag
(no seq_len condition), with single-token vs chunk-tokens routing gated locally on
`seq_len == 1`.
olmo_hybrid has the same kind of bug in its custom `OlmoHybridShortConvolution` — when the
caller had cached state but fed more than one token (`use_precomputed=False` under the old
`seq_len==1` gate), the conv was zero-padded instead of using the cached context. Fix the
caller to drop the `seq_len==1` gate from `use_precomputed`, make the ShortConvolution branch
on `seq_len==1` locally, and route recurrent vs chunk kernel dispatch in the caller likewise.
olmo_hybrid's chunk kernel path already passed `initial_state=recurrent_state` correctly; no
change needed there beyond the `use_precomputed` flag semantics.
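The ShortConvolution side of the fix can be sketched as a toy stateful conv that branches on `seq_len` locally (illustrative only; the real `OlmoHybridShortConvolution` uses fused ops and a different cache layout):

```python
import numpy as np

class ToyShortConvolution:
    """Toy short conv: conv_state holds the last K-1 inputs."""

    def __init__(self, w):
        self.w = w            # (K, dim) depthwise kernel
        self.K = w.shape[0]

    def __call__(self, x, conv_state=None):
        seq_len, dim = x.shape
        if conv_state is not None and seq_len == 1:
            # Single-token cached step: one window, no re-padding.
            window = np.concatenate([conv_state, x], axis=0)   # (K, dim)
            return (window * self.w).sum(axis=0, keepdims=True), window[1:]
        # Multi-token: prepend cached context if present, else zero-pad
        # (fresh prefill).  The old caller always zero-padded here.
        ctx = np.zeros((self.K - 1, dim)) if conv_state is None else conv_state
        xp = np.concatenate([ctx, x], axis=0)
        out = np.stack([(xp[t:t + self.K] * self.w).sum(axis=0)
                        for t in range(seq_len)])
        return out, xp[-(self.K - 1):]

rng = np.random.default_rng(2)
conv = ToyShortConvolution(rng.standard_normal((4, 2)))
x = rng.standard_normal((9, 2))

full, _ = conv(x)                        # reference: one fresh prefill
pre, state = conv(x[:5])                 # prefill populates the cache
cont, _ = conv(x[5:], conv_state=state)  # cached multi-token continuation
step, _ = conv(x[5:6], conv_state=state) # cached single-token decode

print(np.allclose(cont, full[5:]))   # multi-token continuation matches
print(np.allclose(step, cont[:1]))   # and agrees with single-token decode
```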
Add the same causal-LM invariance test to both suites
(`test_linear_attention_multi_token_cached_forward_matches_single_token`). 100% element
mismatch on the old code, passes after the fix. RUN_SLOW integration suites for both models
pass end-to-end.
`modeling_qwen3_next.py` and `modeling_olmo_hybrid.py` regenerated via
`check_modular_conversion.py --fix_and_overwrite`.
* Review feedback: keep "prefill mode / multi-token decode" comment, label prepend block
Restore the historical `# Multi-token forward (prefill mode)` comment on the chunk-mode
else-branch in olmo_hybrid (and the equivalent qwen3_5 / qwen3_next paths) and adjust the
wording so the two intents — fresh prefill vs. cached chunked-tokens decode — are visible
on the same line. Tag the conv-context prepend block with a "dropped at the end of this
branch" hint so a reader knows the mirror operation exists.
Comment-only; behavior unchanged. Generated files re-emitted via
`check_modular_conversion.py --fix_and_overwrite`.