transformers
892be11b - Fix use_cache with seq_len > 1 ( #46032) (#46084)

Commit
8 days ago
Fix use_cache with seq_len > 1 ( #46032) (#46084)

* Fix use_cache with seq_len > 1 ( #46032)

  Add a third dispatch path to Mamba2Mixer for the case where a prior cached state exists and seq_len > 1 (chunked prefill / speculative decode verification).

  Previously:
  - cuda_kernels_forward: .squeeze(1) crashed or corrupted tensors for seq_len > 1 when has_previous_state was True.
  - torch_forward: dt[:, 0, :] silently dropped all tokens after index 0, and previous_states was always zero-initialised, ignoring the cache.

  Both paths now:
  - Gate the single-step kernel on seq_len == 1.
  - For seq_len > 1, prepend the cached conv buffer (last K-1 entries) to the chunk input, run a full causal conv, drop the prepended prefix from the output, and pass the cached recurrent_state as initial_states to mamba_chunk_scan_combined / as previous_states in the naive SSD.

  This PR follows the same pattern applied to GDN-based models in #45513, as discussed in the issues.

  I added test_mamba2_chunked_prefill (CPU/torch path, always runs) and test_mamba2_chunked_prefill_cuda_path (skipped without mamba-ssm).

* Refactor convolution logic to utilize initial_states for causal convolution
* Code quality
* merging the chunked-prefill elif branch, renaming
* making the flow cleaner
* Refactor of has_previous_states and others.
* adding in require_* decorator format
* adding in __init__
* Refactor Mamba2Mixer to make it close to qwen 3 implementation and improving cache usage
* Updated input handling and added separate tests for CPU and torch device execution paths
* adjust tests slightly to require kernels instead

---------

Co-authored-by: vasqu <antonprogamer@gmail.com>
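The chunked-prefill behaviour described in this commit message can be illustrated with a minimal pure-Python sketch. It is not the Mamba2Mixer code: it assumes a single channel, a scalar recurrent state, and a simplified recurrence h_t = a_t * h_{t-1} + x_t, and every name in it (causal_conv, conv_chunk, ssm_scan) is hypothetical. It only shows why carrying the last K-1 conv inputs and the recurrent state across chunks reproduces the full-sequence result, while starting each chunk from zeros does not.

```python
def causal_conv(seq, weights):
    """Single-channel causal conv: output t sees inputs t-K+1 .. t (zero-padded)."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(seq)
    return [sum(weights[j] * padded[t + j] for j in range(k))
            for t in range(len(seq))]


def conv_chunk(chunk, weights, conv_cache):
    """Prepend the cached K-1 inputs, run the full conv, drop the prefix outputs."""
    k = len(weights)
    extended = list(conv_cache) + list(chunk)
    out = causal_conv(extended, weights)[len(conv_cache):]
    new_cache = extended[-(k - 1):] if k > 1 else []
    return out, new_cache


def ssm_scan(xs, decay, initial_state=0.0):
    """Toy scalar recurrence h_t = a_t * h_{t-1} + x_t (an assumed simplification)."""
    h, ys = initial_state, []
    for x_t, a_t in zip(xs, decay):
        h = a_t * h + x_t
        ys.append(h)
    return ys, h


weights = [0.5, -1.0, 2.0, 0.25]                 # kernel size K = 4
xs = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 4.0]
decay = [0.9, 0.8, 0.7, 0.95, 0.6, 0.85, 0.75]
chunks = [(0, 3), (3, 4), (4, 7)]                # includes a seq_len == 1 step

# Reference: process the whole sequence in one call.
full_y, full_h = ssm_scan(causal_conv(xs, weights), decay)

# Chunked: carry the conv cache and the recurrent state between calls.
cache, state, chunked_y = [0.0] * (len(weights) - 1), 0.0, []
for start, stop in chunks:
    conv_out, cache = conv_chunk(xs[start:stop], weights, cache)
    ys, state = ssm_scan(conv_out, decay[start:stop], state)
    chunked_y.extend(ys)

# Naive: ignore the cache, i.e. every chunk starts from zeros.
naive_y = []
for start, stop in chunks:
    conv_out, _ = conv_chunk(xs[start:stop], weights, [0.0] * (len(weights) - 1))
    ys, _ = ssm_scan(conv_out, decay[start:stop], 0.0)
    naive_y.extend(ys)

matches_full = all(abs(a - b) < 1e-12 for a, b in zip(chunked_y, full_y))
naive_differs = any(abs(a - b) > 1e-6 for a, b in zip(naive_y, full_y))
print(matches_full, naive_differs)
```

In this sketch the naive loop plays the role of the zero-initialised previous_states mentioned above, and the chunked loop plays the role of passing the cached state as the initial state of the next chunk.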