[CUDA] GroupQueryAttention operator using FlashAttention #17674
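For context on what the new contrib op computes, here is a minimal NumPy sketch of grouped-query attention, where a small number of key/value heads is shared across groups of query heads (no causal mask or past KV cache; the function name, arguments, and BSNH layout are illustrative only, not the operator's actual signature or CUDA kernel):

```python
import numpy as np

def group_query_attention(q, k, v):
    """Reference grouped-query attention (BSNH layout, no mask, no KV cache).

    q:    (batch, q_seq_len, num_heads, head_size)
    k, v: (batch, kv_seq_len, kv_num_heads, head_size), kv_num_heads divides num_heads
    """
    head_size = q.shape[-1]
    group = q.shape[2] // k.shape[2]
    # Each KV head serves `group` consecutive query heads.
    k = np.repeat(k, group, axis=2)
    v = np.repeat(v, group, axis=2)
    # BSNH -> BNSH so heads act as a batch-like dimension for the matmuls.
    q, k, v = (t.transpose(0, 2, 1, 3) for t in (q, k, v))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return (probs @ v).transpose(0, 2, 1, 3)  # back to BSNH

# Example: 8 query heads sharing 2 KV heads (group size 4).
q = np.random.randn(1, 16, 8, 64).astype(np.float32)
k = np.random.randn(1, 16, 2, 64).astype(np.float32)
v = np.random.randn(1, 16, 2, 64).astype(np.float32)
print(group_query_attention(q, k, v).shape)  # (1, 16, 8, 64)
```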
add flash attention v2
f39854de
remove bfloat16 kernels
aa1efa88
namespace
1fe60ae4
remove commented code
6cb6eba2
remove backward code
893b77ec
path
1195d3f6
flash v2 api and api call
4e59d957
api build fixes
b0f7f47d
builds successfully
0e79ec54
some fixes
2f61ac2d
use varlen to run Attention_Mask1D_FP16_B2 test
372141db
Packed MHA cleanup
95a7e358
flash attention runs
52553a77
update
018ecb6a
pre clean
632451eb
clean PMHA
98a4e5d0
packed mha flash
12f8db86
remove extraneous changes
91c1ab8f
reviewed changes
2fc7d86d
reviewed changes
7cf03591
reviewed changes
f6554a13
reviewed changes
7a124878
reviewed changes
33a43aa7
namespace and USE_FLASH_ATTENTION flag
24524ccb
more compile flags flash
d9002f70
lint
3be4af27
gcc warnings in template
3798446f
address tianlei comments
666d3bf3
refactoring
cad2901a
workspace as buffer. copyright and cgmanifest.
fdb189bd
undo cgmanifest
14d7db2e
enable flash attention in MultiHeadAttention op
aba9cebf
fix attention test error in A100
a06da2eb
namespace from flash to onnxruntime::flash
ab015e63
undo cgmanifest
866414cf
add unit test
6d682ab0
enable flash attention in Attention op
300019e3
Merge branch 'main' into flash_v2_packed_mha
7e50d9a2
set proper nvcc threads to avoid OOM
7481ec34
--nvcc_threads=1 in build_cuda_c_api_package.sh
2823476c
test script segfaults
c8801479
pass cuda device prop to flash attention
d2a040b3
add requirements for test_flash_attn.py
8b92a9b1
remove nvcc_threads logic
a0393e2c
flash attn test, pmha works, mha crashes
0830925a
check head size for efficient attention
1c967391
lint except lambda assignment
6d8e43d2
lint fix
4ceca3ba
line length < 120
3bfa3b55
flash v2 update
bdb17d5c
formatting
bfee28ef
flash benchmark script
a54a7b98
merge with main
b064a010
Update c-api-noopenmp-packaging-pipelines.yml
e7b7f2e9
io binding
f6927f7f
update benchmark
345f4e66
Add bert-base
da8bc503
Merge remote-tracking branch 'origin/main' into flash_v2_packed_mha
40a1f61c
merge main into branch for nuget fix
f7601235
Merge branch 'flash_v2_packed_mha' into flash_v2_no_cuda52
0dc8613d
update benchmark to support more input formats
92652c33
Merge branch 'flash_v2_packed_mha' of https://github.com/microsoft/on…
ee2296fc
seq len threshold to trigger flash for packed qkv
e998af75
add back 2 lines
599d0198
flash attention flag in packed attention op test and a few more bench…
492c59f7
flash attention flag in packed attention op test and a few more bench…
6400a02b
Merge remote-tracking branch 'refs/remotes/origin/flash_v2_packed_mha…
1929de5d
specify TNLGv4 model for Turing Team in Benchmark
c880f084
remove env variable change from packed attention test
30c2f792
python lint
01443ef6
Merge remote-tracking branch 'origin/main' into flash_v2_packed_mha
6a06d9e1
Merge branch 'flash_v2_packed_mha' into flash_v2_no_cuda52
e1eb49aa
start work on group query attention
7605bb49
work on check input and group query attention cc
0697d19e
more work on gqa
b9784dc6
gqa working with causal or without causal
5e7286ec
push before rebase
cb0a96f8
merge with main
afb493ee
gqa with past builds
6053c865
gqa working with past kv
11608be8
Merge remote-tracking branch 'origin/main' into aciddelgado/group_que…
9d31ad13
some code cleaning
9d2f9226
some fixes and clean up
bdb38670
no dumper
362c6aeb
premerge main
04801df4
lint
3a11592a
merge main
2941dbc7
Merge remote-tracking branch 'origin/main' into aciddelgado/group_que…
d78f4769
fix illegal access memory issue
2d0b960b
clean up
5b076f70
bytes
3bf777c5
merge main
cdc65dcc
gqa final touches
0e33dc1b
build fixes gqa
de64ff4e
lint
7a2ad7ca
benchmark gqa vs dmmha
7a476963
fix comments
437d23c3
start work bnsh
365d0b51
bsnh present
470a8a7e
Support for BNSH format
05d1c56b
bnsh attribute and benchmark
27dfac55
past-present bnsh, non-cache past-present.
6d681ee2
merge bnsh and no buff
0e76730d
lint and benchmark script
a7482edf
fix build issue
46f0ce45
fix build pipeline
86792141
pr cleanup
a0ec0eb2
int64 past sequence
b4082d10
small review changes p1
befdb2d1
clang-format and update documentation
3fb6b9c6
tianleiwu dismissed these changes on 2023-10-06
tianleiwu changed the title from "Aciddelgado/group query attention" to "[CUDA] GroupQueryAttention operator using FlashAttention" 2 years ago
ignore whitespace when diff documentation
fcaba356
aciddelgado dismissed their stale review via fcaba356 2 years ago
ignore blank lines
3a06b64a
tianleiwu dismissed these changes on 2023-10-06
formatting whitespace
bbc47f0b
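Several of the commits above concern the BSNH vs. BNSH layouts for the past/present KV cache ("bsnh present", "Support for BNSH format", "past-present bnsh"). A minimal sketch of how the two layouts relate when new keys are appended to a past cache, with made-up shapes and variable names (not the operator's actual buffers or code):

```python
import numpy as np

batch, num_heads, head_size = 2, 4, 64
past_len, new_len = 6, 2

# BSNH: (batch, sequence, num_heads, head_size) -- sequence is axis 1.
past_k_bsnh = np.random.randn(batch, past_len, num_heads, head_size).astype(np.float16)
new_k_bsnh = np.random.randn(batch, new_len, num_heads, head_size).astype(np.float16)
present_k_bsnh = np.concatenate([past_k_bsnh, new_k_bsnh], axis=1)

# BNSH: (batch, num_heads, sequence, head_size) -- same data, sequence is axis 2.
past_k_bnsh = past_k_bsnh.transpose(0, 2, 1, 3)
new_k_bnsh = new_k_bsnh.transpose(0, 2, 1, 3)
present_k_bnsh = np.concatenate([past_k_bnsh, new_k_bnsh], axis=2)

assert present_k_bsnh.shape == (batch, past_len + new_len, num_heads, head_size)
assert present_k_bnsh.shape == (batch, num_heads, past_len + new_len, head_size)
```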
aciddelgado dismissed their stale review via bbc47f0b 2 years ago
tianleiwu approved these changes on 2023-10-09
aciddelgado deleted the aciddelgado/group_query_attention branch 2 years ago
faxu added the triage:approved label
faxu added the sdxl_llama label
Assignees: No one assigned