[JS/WebGPU] GroupQueryAttention rewrite #20946
Fix GroupQueryAttention to enable phi3 use case.
4e49819c
Handle the case where key/value is undefined
17385aa0
Reverted changes to attention.ts
dd6d8926
Exported Split functionality.
a4401b8c
Handle packed QKV by splitting the first input if key/value is null
05198c96
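The packed-QKV handling described in the commit above can be sketched as follows. This is a minimal illustration with a hypothetical helper name (`splitPackedQkv`), not the actual onnxruntime code: when the key/value inputs are null, the first input packs Q, K, and V along the hidden dimension, and must be split into three tensors.

```typescript
// Sketch: a packed QKV input has shape
// [batch, seqLen, (numHeads + 2 * kvNumHeads) * headSize] (row-major).
// Split each row into separate Q, K, V buffers. Hypothetical helper.
function splitPackedQkv(
  packed: Float32Array,
  batch: number,
  seqLen: number,
  numHeads: number,
  kvNumHeads: number,
  headSize: number,
): { q: Float32Array; k: Float32Array; v: Float32Array } {
  const qSize = numHeads * headSize;
  const kvSize = kvNumHeads * headSize;
  const rowSize = qSize + 2 * kvSize;
  const rows = batch * seqLen;
  const q = new Float32Array(rows * qSize);
  const k = new Float32Array(rows * kvSize);
  const v = new Float32Array(rows * kvSize);
  for (let r = 0; r < rows; r++) {
    const base = r * rowSize;
    q.set(packed.subarray(base, base + qSize), r * qSize);
    k.set(packed.subarray(base + qSize, base + qSize + kvSize), r * kvSize);
    v.set(packed.subarray(base + qSize + kvSize, base + rowSize), r * kvSize);
  }
  return { q, k, v };
}
```

In the actual kernel this split is done by the exported Split functionality mentioned earlier in the log rather than by a CPU-side loop.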
Set isPastkvBSNH
bcf04902
Clean up logic.
ec0af227
Revert changes to attention.ts
07410742
output k/v if isPastkvBSNH is not set.
477cba62
set isPastkvBSNH based on pastKey dims.
09e76d39
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
983c4d37
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
b94fea49
Fixed conflicts.
66116b69
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
2a57e705
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
fe8d5ffe
Added headSize check
135a46ae
isPastkvBSNH is always false for GroupQueryAttention
eae3c8ba
Removed explicit concat operation in GQA.
cb3201da
Added seqLens to the inputs.
60ec8e5d
Use seqLen only when pastKey is not empty
67560162
Reshape key/value inputs
efd401b6
WIP
1079f346
tmp
01373b6a
only some tests are passing
15ecb907
group-query-attention-basic* test cases are passing
fc6af3a2
Updated the op tests
d271d044
reverted some changes
402ea297
Added a new test case
ccea6242
Attention and MultiheadAttention tests are passing
6424106b
some GQA tests are failing
44d1cf6b
Removed total_sequence_length variable
8d859056
improve readability
a9d7a98d
Take n_reps into account in address calculations.
bd5f4339
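The n_reps factor in the commit above is the heart of grouped-query attention: each key/value head is shared by `nReps = numHeads / kvNumHeads` query heads, so K/V buffer offsets are derived from the query head index divided by nReps. A minimal sketch with a hypothetical helper name (`kvHeadIndex`), not the shader code itself:

```typescript
// Sketch: map a query head index to the key/value head that serves it.
// In GQA, numHeads is assumed to be an exact multiple of kvNumHeads.
function kvHeadIndex(queryHeadIdx: number, numHeads: number, kvNumHeads: number): number {
  const nReps = numHeads / kvNumHeads; // query heads per kv head
  return Math.floor(queryHeadIdx / nReps);
}
```

With 8 query heads and 2 KV heads, query heads 0..3 read KV head 0 and query heads 4..7 read KV head 1; the later "convert the result of the division to unsigned int before multiplying" commit fixes exactly this kind of truncation in the WGSL offset arithmetic.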
Updated the hint
5364b307
Added new testcases.
82bfde5b
Added a comment
127f20d0
JSONC formatting changes.
bb27e4f1
Formatting changes.
c7cee09d
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
e30d26d8
Reverted changes added for debugging
36075875
Use ternary operator instead of if-else
4f4e8a61
Convert workdgroup_id.z to absKvHeadIdx of key/value offset calculations
0a4c0aa8
convert the result of the division to unsigned int before multiplying
a2bd3c9b
Fixed expected output shape and data to match wasm
cdad003c
Clean up
e52a0f7b
Clean up 2
6ef33ae8
Format
f0246ab6
Ignore total_sequence_length input as it is not used
6a864f78
satyajandhyala marked this pull request as ready for review 1 year ago
satyajandhyala changed the title [JS/WebGPU] group query attention update → [JS/WebGPU] group query attention rewrite 1 year ago
Revert "Ignore total_sequence_length input as it is not used"
63850739
Revert "Format"
26e551ac
Revert "Clean up 2"
2f81b6bf
Added GroupQueryAttentionAttr
5f43e4ab
tmp
80a2b1a2
tmp
e85da0a6
The softmax computation for GQA is not as simple as that of Attention…
4eb7f432
Format
8e382472
typo
50f5795d
tmp
5a71a448
tmp
17953997
Uncomment present_sequence_length
8149d92b
Fix the condition for setting remaining elements to 0 in softmax kernel.
02dc1ffd
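The softmax fix in the commit above concerns the boundary between valid and padded attention scores: only the first `validLen` elements of a row participate in the softmax, and everything past that boundary must be written as 0. A minimal CPU-side sketch of the intended behavior (illustrative only, not the WGSL kernel):

```typescript
// Sketch: numerically stable softmax over the first validLen scores;
// elements at index >= validLen are left at 0 rather than normalized.
// Getting this cutoff condition right is what the commit fixes.
function maskedSoftmax(scores: Float32Array, validLen: number): Float32Array {
  const out = new Float32Array(scores.length); // zero-initialized
  let max = -Infinity;
  for (let i = 0; i < validLen; i++) max = Math.max(max, scores[i]);
  let sum = 0;
  for (let i = 0; i < validLen; i++) {
    out[i] = Math.exp(scores[i] - max);
    sum += out[i];
  }
  for (let i = 0; i < validLen; i++) out[i] /= sum;
  return out; // indices validLen..length-1 remain 0
}
```

If the condition is wrong (e.g. off by one, or normalizing over the full row), padded positions leak probability mass and the attention output diverges from the wasm reference.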
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
641f18c6
typo
665660ef
satyajandhyala changed the title [JS/WebGPU] group query attention rewrite → [JS/WebGPU] GroupQueryAttention rewrite 1 year ago
Reverted changes to hint
23e192af
Support rotary embeddings
8c2caba8
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
a3aed655
Removed parseGroupQueryAttentionAttributes.
ddb9fdba
Keep the existing code.
56f0f3f0
guschmue approved these changes on 2024-10-23
satyajandhyala deleted the sajandhy/webgpu_group_query_attention_update branch 1 year ago