onnxruntime
Flash Attention v2 MHA
#17227
Merged

Commits
  • add flash attention v2
    aciddelgado committed 2 years ago
  • remove bfloat16 kernels
    aciddelgado committed 2 years ago
  • namespace
    aciddelgado committed 2 years ago
  • remove commented code
    aciddelgado committed 2 years ago
  • remove backward code
    aciddelgado committed 2 years ago
  • path
    aciddelgado committed 2 years ago
  • flash v2 api and api call
    aciddelgado committed 2 years ago
  • api build fixes
    aciddelgado committed 2 years ago
  • builds succesfully
    aciddelgado committed 2 years ago
  • some fixes
    aciddelgado committed 2 years ago
  • use varlen to run Attention_Mask1D_FP16_B2 test
    aciddelgado committed 2 years ago
  • Packed MHA cleanup
    aciddelgado committed 2 years ago
  • flash attention runs
    aciddelgado committed 2 years ago
  • update
    aciddelgado committed 2 years ago
  • pre clean
    aciddelgado committed 2 years ago
  • clean PMHA
    aciddelgado committed 2 years ago
  • packed mha flash
    aciddelgado committed 2 years ago
  • remove extraneous changes
    aciddelgado committed 2 years ago
  • reviewed changes
    aciddelgado committed 2 years ago
  • reviewed changes
    aciddelgado committed 2 years ago
  • reviewed changes
    aciddelgado committed 2 years ago
  • reviewed changes
    aciddelgado committed 2 years ago
  • reviewed changes
    aciddelgado committed 2 years ago
  • namespace and USE_FLASH_ATTENTION flag
    aciddelgado committed 2 years ago
  • more compile flags flash
    aciddelgado committed 2 years ago
  • lint
    aciddelgado committed 2 years ago
  • gcc warnings in template
    aciddelgado committed 2 years ago
  • address tianlei comments
    aciddelgado committed 2 years ago
  • clean up
    tlwu@microsoft.com committed 2 years ago
  • cpplint
    tlwu@microsoft.com committed 2 years ago
  • refactoring
    tianleiwu committed 2 years ago
  • workspace as buffer. copyright and cgmanifest.
    aciddelgado committed 2 years ago
  • undo cgmanifest
    aciddelgado committed 2 years ago
  • enable flash attention in MultiHeadAttention op
    tianleiwu committed 2 years ago
  • fix attention test error in A100
    tianleiwu committed 2 years ago
  • namespace from flash to onnxruntime::flash
    tianleiwu committed 2 years ago
  • undo cgmanifest
    tianleiwu committed 2 years ago
  • add unit test
    tianleiwu committed 2 years ago
  • enable flash attention in Attention op
    tianleiwu committed 2 years ago
  • Merge branch 'main' into flash_v2_packed_mha
    tianleiwu committed 2 years ago
  • set proper nvcc threads to avoid OOM
    tianleiwu committed 2 years ago
  • --nvcc_threads=1 in build_cuda_c_api_package.sh
    tianleiwu committed 2 years ago
  • test script segfaults
    aciddelgado committed 2 years ago
  • pass cuda device prop to flash attention
    tianleiwu committed 2 years ago
  • add requirements for test_flash_attn.py
    tianleiwu committed 2 years ago
  • remove nvcc_threads logic
    tianleiwu committed 2 years ago
  • flash attn test, pmha works, mha crashes
    aciddelgado committed 2 years ago
  • check head size for efficient attention
    tianleiwu committed 2 years ago
  • lint except lambda assignment
    aciddelgado committed 2 years ago
  • lint fix
    aciddelgado committed 2 years ago
  • line length < 120
    tianleiwu committed 2 years ago
  • flash v2 update
    aciddelgado committed 2 years ago
  • formatting
    aciddelgado committed 2 years ago
  • flash benchmark script
    aciddelgado committed 2 years ago
  • merge with main
    aciddelgado committed 2 years ago
  • io binding
    aciddelgado committed 2 years ago
  • update benchmark
    tianleiwu committed 2 years ago
  • Add bert-base
    tianleiwu committed 2 years ago
  • Merge remote-tracking branch 'origin/main' into flash_v2_packed_mha
    aciddelgado committed 2 years ago
  • merge main into branch for nuget fix
    aciddelgado committed 2 years ago
  • update benchark to support more input formats
    tianleiwu committed 2 years ago
  • Merge branch 'flash_v2_packed_mha' of https://github.com/microsoft/onnxruntime into flash_v2_packed_mha
    tianleiwu committed 2 years ago
  • seq len threshold to trigger flash for packed qkv
    tianleiwu committed 2 years ago
  • add back 2 lines
    tianleiwu committed 2 years ago
  • flash attention flag in packed attention op test and a few more benchmarks for roli
    aciddelgado committed 2 years ago
  • flash attention flag in packed attention op test and a few more benchmarks for roli
    aciddelgado committed 2 years ago
  • Merge remote-tracking branch 'refs/remotes/origin/flash_v2_packed_mha' into flash_v2_packed_mha
    aciddelgado committed 2 years ago
  • specify TNLGv4 model for Turing Team in Benchmark
    aciddelgado committed 2 years ago
  • remove env variable change from packed attention test
    aciddelgado committed 2 years ago
  • python lint
    aciddelgado committed 2 years ago
Loading