onnxruntime
87076553 - [CUDA] Add SparseAttention kernel for sm=75 (#20531)

Commit
1 year ago
[CUDA] Add SparseAttention kernel for sm=75 (#20531) ### Description Follow up of #20216 to add kernel for sm=75 (GPU like T4, Geforce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000 etc) - [x] Add kernel for sm=75 - [x] Update dispatch code to use sm to call different kernel. - [x] Update compile script to use num_stages=2 instead of 3 for sm=75 - [x] Refactor test script and add tests for bfloat16. - [x] Fix performance test of token generation (previously we did not concatenate past_key) - [x] Fix debug build - [x] Run performance test and update numbers. For sm=70, the v1 kernel can be compiled but there is error in compiling v2 kernel. So it is skipped in this pull request. Performance Test on T4 GPU (using Standard_NC4as_T4_v3 Azure VM) with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8` We compare sparse attention to corresponding GQA with dense causal. Note that GQA with dense need more computation since no sparsity is used. The TORCH-GQA use naive implementation (using cuSPARSE Block-SpMM could be faster). ``` prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16: sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.184173 2.994347 0.089064 2 64.0 0.303300 3.023986 0.107418 3 128.0 0.887795 3.073728 0.174213 4 256.0 2.797654 3.246899 0.357869 5 512.0 10.055048 3.814039 0.893903 6 1024.0 37.849937 5.818439 2.658720 7 2048.0 148.641785 13.638480 7.202690 8 4096.0 OOM 43.556847 17.680954 9 8192.0 OOM 161.628540 44.336670 token-sm75-batch4-head32-d128-local16-vert8-torch.float16: past_sequence_length TORCH-GQA ORT-GQA-Dense ORT-SparseAtt 1 32.0 0.110353 2.996305 0.137509 2 64.0 0.145088 3.006860 0.165424 3 128.0 0.219500 3.036448 0.192001 4 256.0 0.347496 3.071341 0.249125 5 512.0 0.595842 3.135225 0.398726 6 1024.0 1.081216 3.261110 0.612744 7 2048.0 2.060307 3.515578 0.685670 8 4096.0 OOM 4.022986 0.819707 9 8191.0 OOM 5.024528 1.072912 ``` ### Motivation and Context To inference Phi-3-small in T4 GPU
Author
Parents
Loading