onnxruntime
85facd67 - [CUDA] Benchmark GQA on popular LLM models (#20646)

Commit
1 year ago
[CUDA] Benchmark GQA on popular LLM models (#20646)

### Description

Update benchmark_gqa.py to test latency with the attention configurations of popular models (such as Llama3-8B, Llama3-70B, Mixtral-8x22B-v0.1, and Phi-3).

Note that this is the latency of a single GroupQueryAttention node, not of the whole model. For example, packed QKV might take more time in GQA but is faster in the MatMul of the input projection; that overall effect is not measured here.

Example output on A100-SXM4-80GB:

```
prompt-sm80-Llama3-8B-b1-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.019073                 0.016264
 1             32.0       0.017768                 0.017957
 2             64.0       0.023304                 0.023192
 3            128.0       0.032541                 0.031348
 4            256.0       0.048329                 0.049484
 5            512.0       0.095294                 0.095950
 6           1024.0       0.228050                 0.228980
 7           2048.0       0.663820                 0.663308
 8           4096.0       2.243657                 2.242999
 9           8192.0       8.197120                 8.186282

token-sm80-Llama3-8B-b1-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.018516                 0.015398
 1                  32.0       0.015687                 0.016079
 2                  64.0       0.016115                 0.016053
 3                 128.0       0.018727                 0.019413
 4                 256.0       0.036373                 0.035962
 5                 512.0       0.041701                 0.042203
 6                1024.0       0.053730                 0.053750
 7                2048.0       0.076382                 0.075707
 8                4096.0       0.121876                 0.121802
 9                8191.0       0.211292                 0.211254

prompt-sm80-Llama3-8B-b4-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.024558                 0.022070
 1             32.0       0.021276                 0.021406
 2             64.0       0.044172                 0.027789
 3            128.0       0.069100                 0.059071
 4            256.0       0.146569                 0.106717
 5            512.0       0.270472                 0.244461
 6           1024.0       0.690024                 0.692501
 7           2048.0       2.308546                 2.325453
 8           4096.0       8.724295                 8.957337
 9           8192.0      39.030785                41.381378

token-sm80-Llama3-8B-b4-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.018893                 0.018611
 1                  32.0       0.018124                 0.018190
 2                  64.0       0.018115                 0.018156
 3                 128.0       0.023291                 0.023733
 4                 256.0       0.038357                 0.038351
 5                 512.0       0.047117                 0.047792
 6                1024.0       0.066272                 0.065409
 7                2048.0       0.104196                 0.104527
 8                4096.0       0.180557                 0.180424
 9                8191.0       0.332545                 0.332714

prompt-sm80-Llama3-70B-b1-h64_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.040974                 0.015852
 1             32.0       0.017839                 0.018615
 2             64.0       0.023956                 0.022704
 3            128.0       0.044622                 0.035229
 4            256.0       0.080241                 0.075237
 5            512.0       0.143457                 0.144322
 6           1024.0       0.380473                 0.381731
 7           2048.0       1.217328                 1.214505
 8           4096.0       4.305315                 4.286324
 9           8192.0      15.918250                15.933440

token-sm80-Llama3-70B-b1-h64_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.016148                 0.015612
 1                  32.0       0.015616                 0.015616
 2                  64.0       0.016082                 0.016070
 3                 128.0       0.019470                 0.019130
 4                 256.0       0.036617                 0.037296
 5                 512.0       0.042087                 0.042176
 6                1024.0       0.053704                 0.053587
 7                2048.0       0.076918                 0.076365
 8                4096.0       0.122534                 0.121984
 9                8191.0       0.212961                 0.213330

prompt-sm80-Llama3-70B-b4-h64_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.031137                 0.026270
 1             32.0       0.030938                 0.032009
 2             64.0       0.040833                 0.059118
 3            128.0       0.084899                 0.085482
 4            256.0       0.163951                 0.166310
 5            512.0       0.420436                 0.423721
 6           1024.0       1.282019                 1.283482
 7           2048.0       4.397661                 4.420121
 8           4096.0      16.931839                17.456945
 9           8192.0      77.896706                83.007484

token-sm80-Llama3-70B-b4-h64_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.026106                 0.026061
 1                  32.0       0.025678                 0.025589
 2                  64.0       0.025438                 0.025965
 3                 128.0       0.033879                 0.033320
 4                 256.0       0.058078                 0.057656
 5                 512.0       0.078010                 0.078153
 6                1024.0       0.106353                 0.098079
 7                2048.0       0.160039                 0.159153
 8                4096.0       0.282527                 0.283346
 9                8191.0       0.546207                 0.542135

prompt-sm80-Mistral-7B-v0.1-b1-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
 0             16.0       0.015722       0.015655                 0.015666                 0.016150
 1             32.0       0.018590       0.018562                 0.018136                 0.024617
 2             64.0       0.022480       0.023085                 0.023184                 0.023160
 3            128.0       0.029948       0.030581                 0.030839                 0.031464
 4            256.0       0.048532       0.049099                 0.049424                 0.049408
 5            512.0       0.095096       0.095665                 0.096174                 0.096175
 6           1024.0       0.228606       0.228942                 0.228434                 0.229568
 7           2048.0       0.660832       0.661943                 0.662170                 0.663979
 8           4096.0       2.238001       2.243999                 2.242243                 2.241707
 9           8192.0       8.173824       6.147072                 8.187648                 6.152822
10          16384.0      33.826305      14.486015                34.849792                14.938283
11          32768.0     176.702469      32.725330               184.309753                34.736130

token-sm80-Mistral-7B-v0.1-b1-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
 0                  16.0       0.015407       0.016042                 0.016030                 0.015429
 1                  32.0       0.015525       0.016115                 0.016768                 0.016052
 2                  64.0       0.015556       0.016079                 0.015383                 0.016008
 3                 128.0       0.019302       0.018644                 0.018680                 0.019278
 4                 256.0       0.036924       0.035900                 0.036753                 0.036786
 5                 512.0       0.041482       0.041434                 0.041646                 0.042238
 6                1024.0       0.053587       0.052972                 0.052888                 0.052856
 7                2048.0       0.075749       0.075807                 0.076528                 0.075945
 8                4096.0       0.122053       0.122016                 0.122115                 0.122216
 9                8192.0       0.212069       0.121317                 0.211919                 0.121087
10               16384.0       0.394036       0.121202                 0.393661                 0.121483
11               32767.0       0.757216       0.124326                 0.757659                 0.124157

prompt-sm80-Mistral-7B-v0.1-b4-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
 0             16.0       0.018418       0.018911                 0.023387                 0.019256
 1             32.0       0.021085       0.021132                 0.022143                 0.022251
 2             64.0       0.026743       0.026770                 0.027942                 0.027714
 3            128.0       0.057922       0.058483                 0.058800                 0.059402
 4            256.0       0.105927       0.104876                 0.106695                 0.105996
 5            512.0       0.242958       0.242543                 0.244599                 0.244774
 6           1024.0       0.689321       0.689347                 0.691759                 0.692334
 7           2048.0       2.308250       2.304410                 2.321587                 2.317875
 8           4096.0       8.705210       8.713682                 8.927418                 8.903866
 9           8192.0      39.630848      28.227926                41.604607                29.648554
10          16384.0     175.553543      61.422592               183.384064                64.560127
11          32768.0     772.296692     132.006912               813.537292               138.996735

token-sm80-Mistral-7B-v0.1-b4-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
 0                  16.0       0.018127       0.018691                 0.018661                 0.018681
 1                  32.0       0.018183       0.018812                 0.018739                 0.018759
 2                  64.0       0.018081       0.018116                 0.018136                 0.018153
 3                 128.0       0.023257       0.023146                 0.023114                 0.023103
 4                 256.0       0.038665       0.038102                 0.038120                 0.038759
 5                 512.0       0.047181       0.047156                 0.047012                 0.046382
 6                1024.0       0.066047       0.066103                 0.066604                 0.066076
 7                2048.0       0.104427       0.103770                 0.103799                 0.103807
 8                4096.0       0.180951       0.180373                 0.180173                 0.180154
 9                8192.0       0.334018       0.180801                 0.333269                 0.180690
10               16384.0       0.638682       0.180965                 0.638543                 0.180202
11               32767.0       1.249536       0.184779                 1.249963                 0.184624

prompt-sm80-Mixtral-8x22B-v0.1-b1-h48_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.015699                 0.015563
 1             32.0       0.017931                 0.017719
 2             64.0       0.029975                 0.022875
 3            128.0       0.031038                 0.055747
 4            256.0       0.050191                 0.050845
 5            512.0       0.125187                 0.122813
 6           1024.0       0.304004                 0.301824
 7           2048.0       0.936454                 0.931546
 8           4096.0       3.264547                 3.255931
 9           8192.0      12.062719                12.030080
10          16384.0      49.018368                48.970749
11          32768.0     261.211151               254.461945
12          65536.0    1221.138428              1197.559814

token-sm80-Mixtral-8x22B-v0.1-b1-h48_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.015980                 0.016024
 1                  32.0       0.015440                 0.016165
 2                  64.0       0.015987                 0.015979
 3                 128.0       0.020837                 0.018715
 4                 256.0       0.036240                 0.036747
 5                 512.0       0.042477                 0.041813
 6                1024.0       0.052950                 0.052956
 7                2048.0       0.076084                 0.076691
 8                4096.0       0.122233                 0.121540
 9                8192.0       0.212469                 0.212433
10               16384.0       0.394937                 0.394996
11               32768.0       0.757285                 0.757257
12               65535.0       1.484867                 1.485015

prompt-sm80-Mixtral-8x22B-v0.1-b4-h48_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.024119                 0.018755
 1             32.0       0.022214                 0.022267
 2             64.0       0.028045                 0.027562
 3            128.0       0.062894                 0.079766
 4            256.0       0.135146                 0.134483
 5            512.0       0.331323                 0.329094
 6           1024.0       0.984576                 0.982221
 7           2048.0       3.353564                 3.351021
 8           4096.0      12.762113                12.778350
 9           8192.0      58.599422                57.704449
10          16384.0     263.392242               258.709503
11          32768.0    1155.789795              1128.622070
12          65536.0    5014.187012              4874.590332

token-sm80-Mixtral-8x22B-v0.1-b4-h48_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.018148                 0.018813
 1                  32.0       0.018929                 0.018840
 2                  64.0       0.018745                 0.018232
 3                 128.0       0.023864                 0.023822
 4                 256.0       0.038603                 0.038694
 5                 512.0       0.048347                 0.047630
 6                1024.0       0.066957                 0.067392
 7                2048.0       0.105094                 0.105058
 8                4096.0       0.181941                 0.181808
 9                8192.0       0.334227                 0.334324
10               16384.0       0.640429                 0.640961
11               32768.0       1.267897                 1.269120
12               65535.0       2.534238                 2.504408

prompt-sm80-Phi-3-mini-128k-b1-h32_32x96-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.016112                 0.026949
 1             32.0       0.016486                 0.017284
 2             64.0       0.020910                 0.020994
 3            128.0       0.029306                 0.029452
 4            256.0       0.044604                 0.044642
 5            512.0       0.090079                 0.086868
 6           1024.0       0.208169                 0.208094
 7           2048.0       0.604687                 0.607910
 8           4096.0       2.029056                 2.046771
 9           8192.0       7.792128                 7.906303
10          16384.0      34.271233                34.418175
11          32768.0     160.377853               159.980545
12          65536.0     733.443054               734.722046

token-sm80-Phi-3-mini-128k-b1-h32_32_d96-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.016339                 0.015718
 1                  32.0       0.016572                 0.015964
 2                  64.0       0.016182                 0.016192
 3                 128.0       0.019373                 0.018621
 4                 256.0       0.021856                 0.022463
 5                 512.0       0.028943                 0.028888
 6                1024.0       0.041124                 0.041104
 7                2048.0       0.067668                 0.067542
 8                4096.0       0.117528                 0.117447
 9                8192.0       0.216241                 0.215492
10               16384.0       0.413434                 0.414047
11               32768.0       0.811085                 0.810612
12               65536.0       1.606189                 1.606458
13              131071.0       3.193037                 3.192491

prompt-sm80-Phi-3-mini-128k-b4-h32_32x96-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.019385                 0.019403
 1             32.0       0.019801                 0.020006
 2             64.0       0.025958                 0.025376
 3            128.0       0.056445                 0.055909
 4            256.0       0.103180                 0.102221
 5            512.0       0.244224                 0.244360
 6           1024.0       0.703066                 0.709327
 7           2048.0       2.307456                 2.335001
 8           4096.0       8.334522                 8.406760
 9           8192.0      33.340416                33.758209
10          16384.0     144.141312               145.005569
11          32768.0     655.496216               655.656982
12          65536.0    2981.463135              2984.790039

token-sm80-Phi-3-mini-128k-b4-h32_32_d96-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.018701                 0.018185
 1                  32.0       0.020625                 0.019213
 2                  64.0       0.019936                 0.019943
 3                 128.0       0.023648                 0.023689
 4                 256.0       0.030309                 0.030305
 5                 512.0       0.043501                 0.043801
 6                1024.0       0.067314                 0.068014
 7                2048.0       0.108649                 0.108134
 8                4096.0       0.186053                 0.186848
 9                8192.0       0.339973                 0.339742
10               16384.0       0.643288                 0.644366
11               32768.0       1.261468                 1.261510
12               65536.0       2.502252                 2.501820
13              131071.0       4.990437                 4.989521

prompt-sm80-Phi-3-small-128k-b1-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.025280                 0.023331
 1             32.0       0.023071                 0.025931
 2             64.0       0.022883                 0.026258
 3            128.0       0.030658                 0.031445
 4            256.0       0.057659                 0.057073
 5            512.0       0.095589                 0.106579
 6           1024.0       0.228532                 0.229402
 7           2048.0       0.662315                 0.663349
 8           4096.0       2.242885                 2.248095
 9           8192.0       8.194646                 8.180395
10          16384.0      33.926659                35.130882
11          32768.0     175.320068               184.967163
12          65536.0     810.447876               847.632385

token-sm80-Phi-3-small-128k-b1-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.015517                 0.016038
 1                  32.0       0.016372                 0.015477
 2                  64.0       0.015472                 0.016016
 3                 128.0       0.019291                 0.018664
 4                 256.0       0.036250                 0.035990
 5                 512.0       0.041691                 0.042238
 6                1024.0       0.053730                 0.053126
 7                2048.0       0.075912                 0.076439
 8                4096.0       0.121336                 0.121334
 9                8192.0       0.213104                 0.212443
10               16384.0       0.394353                 0.394272
11               32768.0       0.756965                 0.757017
12               65536.0       1.484548                 1.485371
13              131071.0       2.939200                 2.939552

prompt-sm80-Phi-3-small-128k-b4-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.044326                 0.019298
 1             32.0       0.021840                 0.021408
 2             64.0       0.027492                 0.027802
 3            128.0       0.058128                 0.059431
 4            256.0       0.104300                 0.106019
 5            512.0       0.242562                 0.244948
 6           1024.0       0.689614                 0.692305
 7           2048.0       2.297931                 2.312857
 8           4096.0       8.654848                 8.843170
 9           8192.0      38.770176                40.929279
10          16384.0     175.572998               183.692291
11          32768.0     780.126221               820.551697
12          65536.0    3357.564941              3488.527344

token-sm80-Phi-3-small-128k-b4-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.018061                 0.017995
 1                  32.0       0.018225                 0.018851
 2                  64.0       0.018203                 0.018104
 3                 128.0       0.023161                 0.023651
 4                 256.0       0.038421                 0.037673
 5                 512.0       0.047590                 0.046938
 6                1024.0       0.065639                 0.066055
 7                2048.0       0.103545                 0.103581
 8                4096.0       0.180461                 0.179998
 9                8192.0       0.332667                 0.332564
10               16384.0       0.638503                 0.639094
11               32768.0       1.249180                 1.249479
12               65536.0       2.469457                 2.471666
13              131071.0       4.915362                 4.914499

prompt-sm80-Phi-3-medium-128K-b1-h40_10x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.025759                 0.016318
 1             32.0       0.018282                 0.018111
 2             64.0       0.022642                 0.022978
 3            128.0       0.030860                 0.037988
 4            256.0       0.055703                 0.050318
 5            512.0       0.113465                 0.113776
 6           1024.0       0.267678                 0.268292
 7           2048.0       0.795202                 0.797222
 8           4096.0       2.737953                 2.740435
 9           8192.0      10.101760                10.149092
10          16384.0      43.326466                43.990013
11          32768.0     230.886398               229.886978
12          65536.0    1067.412476              1052.922852

token-sm80-Phi-3-medium-128K-b1-h40_10_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.016122                 0.015582
 1                  32.0       0.015594                 0.016262
 2                  64.0       0.016099                 0.015512
 3                 128.0       0.018708                 0.019510
 4                 256.0       0.037582                 0.036341
 5                 512.0       0.042411                 0.041894
 6                1024.0       0.053278                 0.053914
 7                2048.0       0.076553                 0.076636
 8                4096.0       0.121539                 0.121610
 9                8192.0       0.212083                 0.212377
10               16384.0       0.395086                 0.395280
11               32768.0       0.757879                 0.757888
12               65536.0       1.486093                 1.486915
13              131071.0       2.941728                 2.941408

prompt-sm80-Phi-3-medium-128K-b4-h40_10x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0             16.0       0.019448                 0.018872
 1             32.0       0.022290                 0.022380
 2             64.0       0.027986                 0.027955
 3            128.0       0.062699                 0.062175
 4            256.0       0.124868                 0.125247
 5            512.0       0.298873                 0.298169
 6           1024.0       0.862584                 0.863467
 7           2048.0       2.944640                 2.957824
 8           4096.0      11.318656                11.390720
 9           8192.0      52.606976                52.019199
10          16384.0     232.616959               230.360062
11          32768.0    1024.171997              1019.540466
12          65536.0    4377.362305              4354.510742

token-sm80-Phi-3-medium-128K-b4-h40_10_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
 0                  16.0       0.018192                 0.018175
 1                  32.0       0.018999                 0.018319
 2                  64.0       0.018447                 0.018897
 3                 128.0       0.023863                 0.023195
 4                 256.0       0.037712                 0.038192
 5                 512.0       0.048863                 0.048548
 6                1024.0       0.067244                 0.066473
 7                2048.0       0.105203                 0.105021
 8                4096.0       0.180712                 0.180429
 9                8192.0       0.334948                 0.334734
10               16384.0       0.640662                 0.639709
11               32768.0       1.252196                 1.251684
12               65536.0       2.474927                 2.474280
13              131071.0       4.930829                 4.959340
```

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
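Each table title in the example output encodes the benchmarked configuration, for example `prompt-sm80-Llama3-8B-b1-h32_8x128-fp16`. The following is a minimal sketch of a parser for these titles, which may be handy when post-processing the output. The field meanings are inferred from the titles themselves and are not stated in the commit: mode, CUDA compute capability, model name, batch size, number of query heads, number of key/value heads, head size, and data type. The names `GqaConfig` and `parse_label` are hypothetical and are not part of benchmark_gqa.py.

```python
import re
from typing import NamedTuple


class GqaConfig(NamedTuple):
    """Fields inferred from a benchmark table title (hypothetical helper type)."""
    mode: str          # "prompt" or "token"
    sm: int            # CUDA compute capability, e.g. 80 for A100
    model: str         # e.g. "Llama3-8B"
    batch_size: int
    num_heads: int     # assumed: number of query heads
    kv_num_heads: int  # assumed: number of key/value heads
    head_size: int
    dtype: str         # e.g. "fp16"


# Prompt titles use "h32_8x128"; token titles use "h32_8_d128".
_LABEL_RE = re.compile(
    r"^(?P<mode>prompt|token)-sm(?P<sm>\d+)-(?P<model>.+)"
    r"-b(?P<batch>\d+)-h(?P<heads>\d+)_(?P<kv>\d+)(?:x|_d)(?P<head_size>\d+)"
    r"-(?P<dtype>\w+)$"
)


def parse_label(label: str) -> GqaConfig:
    """Parse a table title such as 'prompt-sm80-Llama3-8B-b1-h32_8x128-fp16:'."""
    m = _LABEL_RE.match(label.strip().rstrip(":"))
    if m is None:
        raise ValueError(f"unrecognized benchmark label: {label!r}")
    return GqaConfig(
        mode=m["mode"],
        sm=int(m["sm"]),
        model=m["model"],
        batch_size=int(m["batch"]),
        num_heads=int(m["heads"]),
        kv_num_heads=int(m["kv"]),
        head_size=int(m["head_size"]),
        dtype=m["dtype"],
    )


if __name__ == "__main__":
    cfg = parse_label("prompt-sm80-Llama3-8B-b1-h32_8x128-fp16:")
    # Query heads per key/value head, assuming the inferred field meanings.
    print(cfg, cfg.num_heads // cfg.kv_num_heads)
```

Under the same assumption about the fields, `num_heads // kv_num_heads` gives the number of query heads that share one key/value head, which is 1 for the Phi-3-mini configuration (`h32_32`).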