Fix use_cache with seq_len > 1 (#46032) (#46084)
* Fix use_cache with seq_len > 1 (#46032)
Add a third dispatch path to Mamba2Mixer for the case where a prior cached state exists and seq_len > 1 (chunked prefill / speculative decode verification).
Previously:
- cuda_kernels_forward: .squeeze(1) crashed or corrupted tensors for seq_len > 1 when has_previous_state was True.
- torch_forward: dt[:, 0, :] silently dropped all tokens after index 0, and previous_states was always zero-initialised ignoring the cache.
Both paths now:
- Gate the single-step kernel on seq_len == 1.
- For seq_len > 1, prepend the cached conv buffer (last K-1 entries) to the chunk input, run a full causal conv, drop the prepended prefix from the output, and pass the cached recurrent_state as initial_states to mamba_chunk_scan_combined / as previous_states in the naive SSD.
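A minimal sketch of the seq_len > 1 path described above, in plain Python rather than the actual Mamba2Mixer code. The helpers `causal_conv`, `scan`, and `run_chunked` are hypothetical stand-ins: a single-channel causal convolution that takes the cached last K-1 inputs as a prefix, and a scalar recurrence `h_t = decay * h_{t-1} + x_t` standing in for the SSD scan with an initial state. It only illustrates why carrying both caches makes chunked prefill match a single full pass.

```python
# Illustrative sketch only: plain-Python stand-ins, not the Mamba2Mixer code.

def causal_conv(x, weights, prefix=None):
    """Causal 1D conv over one channel, given as a list of floats.

    `prefix` holds the last K-1 inputs from earlier chunks (zeros if None).
    Returns (outputs aligned with x only, new K-1 input cache).
    """
    k = len(weights)
    if prefix is None:
        prefix = [0.0] * (k - 1)
    padded = list(prefix) + list(x)
    # Output t sees padded[t:t + k], i.e. x[t] and the K-1 inputs before it.
    # Only len(x) outputs are produced, so nothing is emitted for the prefix.
    out = [
        sum(w * v for w, v in zip(weights, padded[t:t + k]))
        for t in range(len(x))
    ]
    new_cache = padded[len(padded) - (k - 1):]
    return out, new_cache


def scan(x, decay, initial_state=0.0):
    """Toy recurrence h_t = decay * h_{t-1} + x_t, seeded with a cached state."""
    h = initial_state
    hs = []
    for v in x:
        h = decay * h + v
        hs.append(h)
    return hs, h


def run_chunked(x, weights, decay, chunk_sizes):
    """Process x chunk by chunk, carrying the conv cache and recurrent state."""
    conv_cache, state, outputs, start = None, 0.0, [], 0
    for size in chunk_sizes:
        chunk = x[start:start + size]
        conv_out, conv_cache = causal_conv(chunk, weights, conv_cache)
        hs, state = scan(conv_out, decay, state)
        outputs.extend(hs)
        start += size
    return outputs


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
weights = [0.5, -1.0, 0.25, 2.0]  # kernel size K = 4, so the cache holds 3 inputs

full = run_chunked(x, weights, 0.5, [7])           # one prefill pass
chunked = run_chunked(x, weights, 0.5, [3, 1, 3])  # prefill, single step, chunk
```

Here `full` and `chunked` agree because each chunk sees the same K-1 preceding inputs and the same starting state as it would in a single pass. Dropping either cache (a zero prefix or a zero initial state) changes the outputs of every chunk after the first.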
This PR follows the same pattern applied to GDN-based models in #45513, as discussed in the issues.
I added test_mamba2_chunked_prefill (CPU/torch path, always runs) and test_mamba2_chunked_prefill_cuda_path (skipped without mamba-ssm).
* Refactor convolution logic to utilize initial_states for causal convolution
* Code quality
* merging the chunked-prefill elif branch, renaming
* making the flow cleaner
* Refactor of has_previous_states and others.
* adding in require_* decorator format
* adding in __init__
* Refactor Mamba2Mixer to bring it closer to the Qwen 3 implementation and improve cache usage
* Updated input handling and added separate tests for CPU and torch device execution paths
* adjust tests slightly to require kernels instead
---------
Co-authored-by: vasqu <antonprogamer@gmail.com>