onnxruntime
ebd0368b - Make Flash Attention work on Windows (#21015)

Commit
1 year ago
Make Flash Attention work on Windows (#21015) ### Description Previously, Flash Attention only worked on Linux systems. This PR will make it work and enable it to be built and run on Windows. Limitations of Flash Attention in Windows: Requires CUDA 12. ### Motivation and Context This will significantly increase the performance of Windows-based LLM's with hardware sm>=80. To illustrate the improvement of Flash Attention over Memory Efficient Attention, here are some average benchmark numbers for the GQA operator, run with configurations based on several recent models (Llama, Mixtral, Phi-3). The benchmarks were obtained on RTX4090 GPU using the test script located at (onnxruntime/test/python/transformers/benchmark_gqa_windows.py). * Clarifying Note: These benchmarks are just for the GQA operator, not the entire model. ### Memory Efficient Attention Kernel Benchmarks: | Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) | |----------------------------------------|---------------------|-------------------------|-----------------------------| | Llama3-8B (Average Prompt) | 8192 | 0.19790525 | 13105.63425 | | Llama3-8B (Average Token) | 8192 | 0.207775538 | 12025.10172 | | Llama3-70B (Average Prompt) | 8192 | 0.216049167 | 11563.31185 | | Llama3-70B (Average Token) | 8192 | 0.209730731 | 12284.38149 | | Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.371928785 | 7031.440056 | | Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.2996659 | 7607.947159 | | Phi-3-mini-128k (Average Prompt) | 131072 | 0.183195867 | 15542.0852 | | Phi-3-mini-128k (Average Token) | 131072 | 0.198215688 | 12874.53494 | | Phi-3-small-128k (Average Prompt) | 65536 | 2.9884929 | 2332.584142 | | Phi-3-small-128k (Average Token) | 65536 | 0.845072406 | 2877.85822 | | Phi-3-medium-128K (Average Prompt) | 32768 | 0.324974429 | 8094.909517 | | Phi-3-medium-128K (Average Token) | 32768 | 0.263662567 | 8978.463687 | ### Flash Attention Kernel Benchmarks: | Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) | |--------------------------------------|---------------------|-------------------------|-----------------------------| | Llama3-8B (Average Prompt) | 8192 | 0.163566292 | 16213.69057 | | Llama3-8B (Average Token) | 8192 | 0.161643692 | 16196.14715 | | Llama3-70B (Average Prompt) | 8192 | 0.160510375 | 17448.67753 | | Llama3-70B (Average Token) | 8192 | 0.169427308 | 14702.62043 | | Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.164121964 | 15618.51301 | | Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.1715865 | 14524.32273 | | Phi-3-mini-128k (Average Prompt) | 131072 | 0.167527167 | 14576.725 | | Phi-3-mini-128k (Average Token) | 131072 | 0.175940594 | 15762.051 | | Phi-3-small-128k (Average Prompt) | 65536 | 0.162719733 | 17824.494 | | Phi-3-small-128k (Average Token) | 65536 | 0.14977525 | 16749.19858 | | Phi-3-medium-128K (Average Prompt) | 32768 | 0.156490786 | 17679.2513 | | Phi-3-medium-128K (Average Token) | 32768 | 0.165333833 | 14932.26079 | Flash Attention is consistently faster for every configuration we benchmarked, with improvements in our trials ranging from ~20% to ~650%. In addition to these improvements in performance, Flash Attention has better memory usage. For example, Memory Efficient Attention cannot handle a max sequence length higher than 32,768, but Flash Attention can handle max sequence lengths at least as high as 131,072. --------- Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Author
Parents
Loading