onnxruntime
f7fd3b52 - [webgpu] Register GQA based on graph capture (#26384)

Commit
88 days ago
[webgpu] Register GQA based on graph capture (#26384) This pull request enables conditionally register GQA with total_sequence_length on gpu or not. It resolves the issue that a MemcpyToHost is generated when graph capture is enabled (refer to #25868). This is the last functionality part to support graph capture in webgpu ep in ORT. The main changes ensure that when graph capture is enabled, sequence length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models. In this PR, we still get total sequence length from `seqlen_k` tensor not `total_seqlen_tensor` tensor to keep consistent with other parts. In the next PR, we can refactor all places to directly use `total_seqlen_tensor` instead of `seqlen_k` when graph capture enabled.
Author
Parents
Loading