onnxruntime
0fedb26c - Add LinearAttention and CausalConvState ops for Qwen3.5 (#27907)

Committed 7 days ago
Add LinearAttention and CausalConvState ops for Qwen3.5 (#27907)

Adds custom CUDA and CPU kernels for linear attention and causal 1D convolution with state, enabling efficient inference of Qwen3.5 hybrid decoder models in ONNX Runtime.

### New Operators

**`LinearAttention`** — Implements the GatedDeltaNet recurrent linear attention mechanism:
- Fused kernel computing the gated delta-rule update of a recurrent state matrix
- Supports both prefill (multi-token) and decode (single-token) paths
- Inputs: Q, K, V, decay (alpha), beta gating, optional initial recurrent state
- Outputs: attention output, updated recurrent state
- CUDA implementation with per-head parallelism; CPU implementation with Eigen

**`CausalConvWithState`** — Implements causal 1D convolution with persistent state for autoregressive decoding:
- Supports prefill (full convolution) and decode (state-based sliding window)
- Inputs: input tensor, conv weights, optional bias, optional initial conv state
- Outputs: convolution output, updated conv state

### Op Definitions

- Registered in the `com.microsoft` domain (opset 1)
- Full shape inference and type constraints in `bert_defs.cc`

### Testing

- Parity test (`test_parity_linear_attention_causal_conv.py`) validates the CUDA and CPU kernels against PyTorch reference implementations from the FLA (Flash Linear Attention) library

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
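To make the `LinearAttention` decode path concrete, here is a minimal NumPy sketch of one single-token step of a gated delta-rule recurrence, following one common formulation of GatedDeltaNet (decay the state by `alpha`, then apply a `beta`-weighted rank-1 delta-rule correction). The function name, shapes, and exact gating placement are illustrative assumptions; the fused kernel in this commit may order or batch these operations differently.

```python
import numpy as np

def gated_delta_step(S, q, k, v, alpha, beta):
    """One decode-time step of a gated delta-rule linear attention recurrence.

    S:     (d_k, d_v) recurrent state matrix
    q, k:  (d_k,) query and key for the current token
    v:     (d_v,) value for the current token
    alpha: scalar decay gate in [0, 1]
    beta:  scalar write-strength gate in [0, 1]

    NOTE: hypothetical reference sketch, not the op's actual kernel.
    """
    S = alpha * S                            # decay the old state
    pred = S.T @ k                           # state's current prediction for key k: (d_v,)
    S = S + np.outer(k, beta * (v - pred))   # rank-1 delta-rule correction
    o = S.T @ q                              # attention output: (d_v,)
    return o, S
```

Prefill would apply the same recurrence over all tokens in sequence (or as a chunked parallel scan in an optimized kernel), emitting the final state so decode can resume from it.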
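Similarly, the `CausalConvWithState` decode path amounts to a per-channel sliding window over the last few inputs. The sketch below assumes a depthwise causal conv with kernel width `K`, where the state holds the previous `K-1` inputs per channel; the function name and tensor layout are assumptions for illustration, not the op's actual I/O contract.

```python
import numpy as np

def causal_conv_decode_step(x, state, weight, bias=None):
    """Single-token decode step of a causal depthwise 1-D convolution.

    x:      (channels,) new input for the current token
    state:  (channels, K-1) the previous K-1 inputs per channel
    weight: (channels, K) per-channel filter taps
    bias:   optional (channels,)

    NOTE: hypothetical reference sketch; layouts may differ from the op.
    """
    window = np.concatenate([state, x[:, None]], axis=1)  # (channels, K)
    y = (window * weight).sum(axis=1)                     # depthwise dot product
    if bias is not None:
        y = y + bias
    new_state = window[:, 1:]                             # slide the window forward
    return y, new_state
```

Prefill runs the full convolution over the prompt and emits the trailing `K-1` inputs as the initial state, so subsequent decode steps only touch this small buffer instead of the whole sequence.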