c0409aaf - Add FlexAttention (#2443)

Summary:

```
+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention

(Batch, Heads, SeqLen, Dhead)      sdpa-latency    flash_v2-latency    flex_attention-latency
-------------------------------  --------------  ------------------  ------------------------
(32, 16, 512, 128)                      0.24512            0.236768                  0.257984
(16, 16, 1024, 128)                    0.442944            0.41968                   0.419008
(8, 16, 2048, 128)                     0.845728            0.798688                  0.74368
(4, 16, 4096, 128)                      1.65574            1.55882                   1.4041
(2, 16, 8192, 128)                      3.27904            3.08669                   2.73846
(1, 16, 16384, 128)                     6.55098            6.14246                   5.38931

+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --causal

(Batch, Heads, SeqLen, Dhead)      sdpa-latency    flash_v2-latency    flex_attention-latency
-------------------------------  --------------  ------------------  ------------------------
(32, 16, 512, 128)                     0.199136            0.187424                  0.201632
(16, 16, 1024, 128)                    0.298208            0.278048                  0.28912
(8, 16, 2048, 128)                      0.51504            0.481088                  0.46528
(4, 16, 4096, 128)                       0.9584            0.890688                  0.82928
(2, 16, 8192, 128)                      1.84317            1.70605                   1.55763
(1, 16, 16384, 128)                     3.62157            3.34694                   3.02374

+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --bwd

(Batch, Heads, SeqLen, Dhead)      sdpa-latency    flash_v2-latency    flex_attention-latency
-------------------------------  --------------  ------------------  ------------------------
(32, 16, 512, 128)                      1.36323            1.30051                   0.94192
(16, 16, 1024, 128)                     1.89187            1.80678                   1.40486
(8, 16, 2048, 128)                      2.93325            2.83165                   2.33082
(4, 16, 4096, 128)                      5.05456            4.91002                   4.19229
(2, 16, 8192, 128)                      9.34131            9.10381                   7.87952
(1, 16, 16384, 128)                     17.9824            17.5658                   15.4029

+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --bwd --causal

(Batch, Heads, SeqLen, Dhead)      sdpa-latency    flash_v2-latency    flex_attention-latency
-------------------------------  --------------  ------------------  ------------------------
(32, 16, 512, 128)                      1.14022            1.07926                  0.838688
(16, 16, 1024, 128)                     1.41398            1.33376                   1.06554
(8, 16, 2048, 128)                      1.97437            1.87725                   1.53197
(4, 16, 4096, 128)                      3.08925            2.95606                   2.48413
(2, 16, 8192, 128)                      5.28237            5.08138                   4.37834
(1, 16, 16384, 128)                     9.27056            8.94672                   8.24531
```

Pull Request resolved: https://github.com/pytorch/benchmark/pull/2443

Reviewed By: xuzhao9

Differential Revision: D61924897

Pulled By: bertmaher

fbshipit-source-id: 40e18e6bee2b91e4a9826f5056a431950ee3495d
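For context, here is a minimal sketch of the operator being benchmarked, not the TritonBench harness code itself. It assumes PyTorch >= 2.5 (which ships `torch.nn.attention.flex_attention`) and a CUDA device; the shape is taken from one row of the tables above, while the bfloat16 dtype is an illustrative assumption. Latencies in the tables are presumably milliseconds, the unit of Triton's `do_bench` that TritonBench builds on.

```python
# A minimal sketch, NOT the TritonBench harness code. Assumes PyTorch >= 2.5
# (torch.nn.attention.flex_attention) and a CUDA device; the shape below is
# one row from the tables above, and the dtype is an assumption.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 8, 16, 2048, 128  # (Batch, Heads, SeqLen, Dhead)
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)

# Causal masking expressed as a mask_mod, mirroring the --causal runs.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S)

# FlexAttention is meant to be torch.compile'd into a fused Triton kernel;
# the eager path is only a slow reference implementation.
flex_compiled = torch.compile(flex_attention)
out_flex = flex_compiled(q, k, v, block_mask=block_mask)

# The sdpa baseline column: stock scaled_dot_product_attention.
out_sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# The --bwd runs additionally time a backward pass like this one.
out_flex.sum().backward()
```

The tables reflect the expected behavior of such a compiled kernel: flex_attention trails sdpa and flash_v2 slightly at the shortest sequences, pulls ahead as SeqLen grows, and is fastest across the board in the backward (--bwd) runs.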