Flash Attention v2 MHA #17227
add flash attention v2
f39854de
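For orientation, the kernels added here implement the flash attention recurrence: keys and values are streamed in blocks while a running max and running sum keep the softmax numerically stable, so the full attention matrix is never materialized. The sketch below is a plain C++ restatement of that recurrence for a single query (block size 1); it is illustrative only and is not the PR's CUDA code, which comes from the upstream flash-attn v2 kernels.

```cpp
// Single-query sketch of the online-softmax recurrence behind flash attention.
#include <cmath>
#include <vector>

std::vector<float> OnlineSoftmaxAttention(const std::vector<float>& q,
                                          const std::vector<std::vector<float>>& k,
                                          const std::vector<std::vector<float>>& v,
                                          float scale) {
  const size_t d = q.size();
  std::vector<float> out(d, 0.0f);
  float running_max = -INFINITY;
  float running_sum = 0.0f;
  for (size_t j = 0; j < k.size(); ++j) {  // visit keys/values one block at a time
    float s = 0.0f;
    for (size_t t = 0; t < d; ++t) s += q[t] * k[j][t];
    s *= scale;
    // Rescale the accumulators whenever a new maximum is seen.
    float new_max = std::max(running_max, s);
    float correction = std::exp(running_max - new_max);
    float p = std::exp(s - new_max);
    running_sum = running_sum * correction + p;
    for (size_t t = 0; t < d; ++t) out[t] = out[t] * correction + p * v[j][t];
    running_max = new_max;
  }
  for (size_t t = 0; t < d; ++t) out[t] /= running_sum;  // final normalization
  return out;
}
```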
remove bfloat16 kernels
aa1efa88
namespace
1fe60ae4
remove commented code
6cb6eba2
remove backward code
893b77ec
path
1195d3f6
flash v2 api and api call
4e59d957
api build fixes
b0f7f47d
builds successfully
0e79ec54
some fixes
2f61ac2d
use varlen to run Attention_Mask1D_FP16_B2 test
372141db
Packed MHA cleanup
95a7e358
flash attention runs
52553a77
update
018ecb6a
pre clean
632451eb
clean PMHA
98a4e5d0
packed mha flash
12f8db86
remove extraneous changes
91c1ab8f
reviewed changes
2fc7d86d
reviewed changes
7cf03591
reviewed changes
f6554a13
reviewed changes
7a124878
reviewed changes
33a43aa7
tianleiwu
changed the title from "Flash v2 packed mha" to "Flash Attention v2 packed mha" 2 years ago
namespace and USE_FLASH_ATTENTION flag
24524ccb
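The build gating introduced above can be pictured roughly as follows; USE_FLASH_ATTENTION is the flag named in the commit, while the surrounding layout is an assumption rather than the exact code in this PR.

```cpp
// Hypothetical sketch of gating the flash kernels behind a compile-time flag.
#if USE_FLASH_ATTENTION

namespace onnxruntime {
namespace flash {

// Flash Attention v2 launchers live in their own namespace so they do not
// collide with symbols from the upstream flash-attn sources.
void run_mha_fwd(/* stream, params, ... */);

}  // namespace flash
}  // namespace onnxruntime

#endif  // USE_FLASH_ATTENTION
```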
more compile flags flash
d9002f70
lint
3be4af27
gcc warnings in template
3798446f
address tianlei comments
666d3bf3
snnn
commented
on 2023-08-22
refactoring
cad2901a
workspace as buffer. copyright and cgmanifest.
fdb189bd
undo cgmanifest
14d7db2e
enable flash attention in MultiHeadAttention op
aba9cebf
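Conceptually, enabling flash attention inside the MultiHeadAttention op means adding a runtime gate ahead of the existing kernel choices. The sketch below shows the shape of such a gate; the environment variable name ORT_DISABLE_FLASH_ATTENTION and the helper names are assumptions for illustration, not necessarily the identifiers used in this PR.

```cpp
// Illustrative dispatch gate only; identifiers and limits are hypothetical.
#include <cstdlib>

bool IsFlashAttentionEnabled() {
  // Assumed kill-switch environment variable for debugging and regression checks.
  const char* disabled = std::getenv("ORT_DISABLE_FLASH_ATTENTION");
  return disabled == nullptr || disabled[0] == '0';
}

bool UseFlashAttention(int sm, int head_size, bool is_fp16) {
  // Flash Attention v2 targets Ampere (SM 8.0) and newer GPUs; the bfloat16
  // kernels were removed in this PR, so only fp16 inputs take this path here.
  return IsFlashAttentionEnabled() && sm >= 80 && is_fp16 && head_size <= 256;
}
```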
fix attention test error in A100
a06da2eb
namespace from flash to onnxruntime::flash
ab015e63
undo cgmanifest
866414cf
add unit test
6d682ab0
enable flash attention in Attention op
300019e3
tianleiwu
changed the title from "Flash Attention v2 packed mha" to "Flash Attention v2 MHA" 2 years ago
Merge branch 'main' into flash_v2_packed_mha
7e50d9a2
set proper nvcc threads to avoid OOM
7481ec34
--nvcc_threads=1 in build_cuda_c_api_package.sh
2823476c
test script segfaults
c8801479
pass cuda device prop to flash attention
d2a040b3
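Passing the CUDA device properties into the flash attention path lets the op check compute capability once instead of querying the device on every call. A minimal sketch, assuming the decision only needs the SM major version:

```cpp
#include <cuda_runtime.h>

// Flash Attention v2 requires Ampere (SM 8.0) or newer.
bool FlashSupportedOnDevice(const cudaDeviceProp& prop) {
  return prop.major >= 8;
}

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, /*device=*/0);  // query once, then pass around
  bool use_flash = FlashSupportedOnDevice(prop);
  (void)use_flash;
  return 0;
}
```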
add requirements for test_flash_attn.py
8b92a9b1
remove nvcc_threads logic
a0393e2c
flash attn test, pmha works, mha crashes
0830925a
check head size for efficient attention
1c967391
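The head-size check above guards the fallback between the flash and memory-efficient kernels. The sketch below only illustrates the shape of that guard; the specific limits and alignment rules are assumptions, not the values hard-coded in this PR.

```cpp
// Hypothetical head-size guards; real limits depend on the kernel build.
bool HeadSizeSupportedByFlash(int head_size) {
  return head_size > 0 && head_size <= 256;  // assumed flash kernel limit
}

bool HeadSizeSupportedByEfficientAttention(int head_size) {
  // Memory-efficient attention typically tolerates a wider range of head
  // sizes than the flash kernels, subject to an alignment constraint.
  return head_size > 0 && head_size % 8 == 0;  // assumed alignment
}
```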
faxu
removed release:1.16
lint except lambda assignment
6d8e43d2
lint fix
4ceca3ba
line length < 120
3bfa3b55
flash v2 update
bdb17d5c
formatting
bfee28ef
flash benchmark script
a54a7b98
merge with main
b064a010
io binding
f6927f7f
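The benchmark's I/O binding keeps inputs and outputs on the GPU between runs so host-device copies do not pollute the timings. The PR's benchmark script is Python; the sketch below shows the same idea with the public C++ Ort::IoBinding API, using a placeholder model path, tensor names, and shape.

```cpp
// Sketch only: placeholder model path, tensor names, and shape.
#include <onnxruntime_cxx_api.h>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "flash_benchmark");
  Ort::SessionOptions options;
  OrtCUDAProviderOptions cuda_options{};  // assumes a CUDA-enabled build
  options.AppendExecutionProvider_CUDA(cuda_options);
  Ort::Session session(env, "model.onnx", options);

  // Bind the input once and let ORT allocate the output on the GPU, so
  // repeated Run() calls avoid per-iteration host<->device copies.
  std::vector<int64_t> shape{1, 128, 768};
  std::vector<float> data(1 * 128 * 768, 0.0f);
  Ort::MemoryInfo cpu_mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<float>(
      cpu_mem, data.data(), data.size(), shape.data(), shape.size());

  Ort::MemoryInfo cuda_mem("Cuda", OrtArenaAllocator, /*device_id=*/0, OrtMemTypeDefault);
  Ort::IoBinding binding(session);
  binding.BindInput("input", input);       // placeholder input name
  binding.BindOutput("output", cuda_mem);  // placeholder output name

  for (int i = 0; i < 100; ++i) {          // timed loop in a real benchmark
    session.Run(Ort::RunOptions{}, binding);
  }
  return 0;
}
```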
update benchmark
345f4e66
Add bert-base
da8bc503
Merge remote-tracking branch 'origin/main' into flash_v2_packed_mha
40a1f61c
merge main into branch for nuget fix
f7601235
update benchmark to support more input formats
92652c33
Merge branch 'flash_v2_packed_mha' of https://github.com/microsoft/on…
ee2296fc
seq len threshold to trigger flash for packed qkv
e998af75
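The threshold commit keeps short packed-QKV sequences on the existing fused kernel and only switches to flash attention once the sequence is long enough to benefit. A hedged sketch of that selection; the constant value and names are illustrative, not the PR's exact code.

```cpp
// Illustrative only: the threshold value and function names are assumptions.
constexpr int kMinSeqLenForFlashAttentionPackedQkv = 513;  // assumed cut-over point

bool UseFlashForPackedQkv(int sequence_length, bool flash_supported) {
  // Short sequences stay on the fused kernel, which wins at small sizes;
  // flash attention v2 takes over once the sequence is long enough.
  return flash_supported && sequence_length >= kMinSeqLenForFlashAttentionPackedQkv;
}
```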
add back 2 lines
599d0198
flash attention flag in packed attention op test and a few more bench…
492c59f7
flash attention flag in packed attention op test and a few more bench…
6400a02b
Merge remote-tracking branch 'refs/remotes/origin/flash_v2_packed_mha…
1929de5d
specify TNLGv4 model for Turing Team in Benchmark
c880f084
remove env variable change from packed attention test
30c2f792
tianleiwu
dismissed these changes
on 2023-08-31
python lint
01443ef6
aciddelgado
dismissed their stale review
via 01443ef6
2 years ago
yufenglee
approved these changes
on 2023-08-31
tianleiwu
approved these changes
on 2023-08-31
aciddelgado
deleted the flash_v2_packed_mha branch 2 years ago
natke
added triage:approved