onnxruntime
[JS/WebGPU] GroupQueryAttention rewrite
#20946
Merged

[JS/WebGPU] GroupQueryAttention rewrite #20946

satyajandhyala
satyajandhyala Fix GroupQueryAttention to enable phi3 use case.
4e49819c
satyajandhyala Handle if key/value is undefined
17385aa0
satyajandhyala Reverted changes to attention.ts
dd6d8926
satyajandhyala Exported Split functionality.
a4401b8c
satyajandhyala Handle packed QKV by splitting the first input if key/value is null
05198c96
satyajandhyala Set isPastkvBSNH
bcf04902
satyajandhyala Clean up logic.
ec0af227
satyajandhyala Revert changes attention.ts
07410742
satyajandhyala output k/v if isPastkvBSNH not set.
477cba62
satyajandhyala set isPastkvBSNH based on pastKey dims.
09e76d39
satyajandhyala Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
983c4d37
satyajandhyala Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
b94fea49
satyajandhyala Fixed conflicts.
66116b69
satyajandhyala Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
2a57e705
satyajandhyala Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
fe8d5ffe
satyajandhyala Added headSize check
135a46ae
satyajandhyala isPastkvBSNH is always false for GroupQueryAttention
eae3c8ba
satyajandhyala Removed explict concat operation in GQA.
cb3201da
satyajandhyala Added seqLens to the inputs.
60ec8e5d
satyajandhyala Use seqLen only when pastKay is not empty
67560162
satyajandhyala Reshape key/value inputs
efd401b6
satyajandhyala WIP
1079f346
satyajandhyala tmp
01373b6a
satyajandhyala only some tests are passing
15ecb907
satyajandhyala group-query-attention-basic* test cases are passing
fc6af3a2
satyajandhyala Updated the op tests
d271d044
satyajandhyala reverted some changes
402ea297
satyajandhyala Added a new test case
ccea6242
satyajandhyala Attention and MultiheadAttentin tests are passing
6424106b
satyajandhyala some GQA tests are failing
44d1cf6b
satyajandhyala Removed total_sequence_length variable
8d859056
satyajandhyala improve readability
a9d7a98d
satyajandhyala Take n_reps in to address calculations.
bd5f4339
satyajandhyala Updated the hint
5364b307
satyajandhyala Added new testcases.
82bfde5b
satyajandhyala Added a comment
127f20d0
satyajandhyala JSONC formatting changes.
bb27e4f1
satyajandhyala Formatting changes.
c7cee09d
satyajandhyala Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
e30d26d8
satyajandhyala Reverted changes added for debugging
36075875
satyajandhyala Use ternary operator instead of if-else
4f4e8a61
satyajandhyala Convert workdgroup_id.z to absKvHeadIdx of key/value offset calculations
0a4c0aa8
satyajandhyala convert the result of the division to unsigned int before multiplying
a2bd3c9b
satyajandhyala Fixed expected output shape and data to match wasm
cdad003c
satyajandhyala Clean up
e52a0f7b
satyajandhyala Clean up 2
6ef33ae8
satyajandhyala Format
f0246ab6
satyajandhyala Ignore total_sequence_length input as it is not used
6a864f78
satyajandhyala satyajandhyala marked this pull request as ready for review 1 year ago
satyajandhyala satyajandhyala changed the title [JS/WebGPU] group query attention update [JS/WebGPU] group query attention rewrite 1 year ago
satyajandhyala Revert "Ignore total_sequence_length input as it is not used"
63850739
satyajandhyala Revert "Format"
26e551ac
satyajandhyala Revert "Clean up 2"
2f81b6bf
satyajandhyala Added GroupQueryAttentionAttr
5f43e4ab
satyajandhyala tmp
80a2b1a2
satyajandhyala tmp
e85da0a6
satyajandhyala The softmax computation for GQA is not as simple as that of Attention…
4eb7f432
satyajandhyala Format
8e382472
github-actions
github-actions dismissed these changes on 2024-10-15
satyajandhyala typo
50f5795d
satyajandhyala tmp
5a71a448
guschmue
satyajandhyala tmp
17953997
satyajandhyala Uncomment present_sequence_length
8149d92b
satyajandhyala Fix the condition for setting remaining elements to 0 in softmax kernel.
02dc1ffd
satyajandhyala Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
641f18c6
satyajandhyala typo
665660ef
satyajandhyala satyajandhyala changed the title [JS/WebGPU] group query attention rewrite [JS/WebGPU] GroupQueryAttention rewrite 1 year ago
satyajandhyala Reverted changes to hint
23e192af
satyajandhyala Support rotary embeddings
8c2caba8
satyajandhyala Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
a3aed655
satyajandhyala satyajandhyala dismissed their stale review 1 year ago
Already made the required changes.
satyajandhyala satyajandhyala added ep:WebGPU
satyajandhyala Removed parseGroupQueryAttentionAttributes.
ddb9fdba
satyajandhyala Keep the existing code.
56f0f3f0
guschmue
guschmue approved these changes on 2024-10-23
satyajandhyala satyajandhyala merged fd8ee489 into main 1 year ago
satyajandhyala satyajandhyala deleted the sajandhy/webgpu_group_query_attention_update branch 1 year ago
satyajandhyala satyajandhyala restored the head branch 1 year ago

Login to write a write a comment.

Login via GitHub

Assignees
No one assigned
Labels
Milestone