Add FlexAttention (#2443)
Summary:
Adds a `flex_attention` implementation to the `flash_attention` operator in the `triton` benchmark suite so it can be compared against the existing `sdpa` and `flash_v2` backends. The runs below cover d_head=128, forward and backward, with and without causal masking; at these shapes `flex_attention` roughly matches `flash_v2` in the forward pass (a bit slower at short sequence lengths, faster at long ones) and is faster across the board in the backward pass.
```
+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention
(Batch, Heads, SeqLen, Dhead) sdpa-latency flash_v2-latency flex_attention-latency
------------------------------- -------------- ------------------ ------------------------
(32, 16, 512, 128) 0.24512 0.236768 0.257984
(16, 16, 1024, 128) 0.442944 0.41968 0.419008
(8, 16, 2048, 128) 0.845728 0.798688 0.74368
(4, 16, 4096, 128) 1.65574 1.55882 1.4041
(2, 16, 8192, 128) 3.27904 3.08669 2.73846
(1, 16, 16384, 128) 6.55098 6.14246 5.38931
+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --causal
(Batch, Heads, SeqLen, Dhead) sdpa-latency flash_v2-latency flex_attention-latency
------------------------------- -------------- ------------------ ------------------------
(32, 16, 512, 128) 0.199136 0.187424 0.201632
(16, 16, 1024, 128) 0.298208 0.278048 0.28912
(8, 16, 2048, 128) 0.51504 0.481088 0.46528
(4, 16, 4096, 128) 0.9584 0.890688 0.82928
(2, 16, 8192, 128) 1.84317 1.70605 1.55763
(1, 16, 16384, 128) 3.62157 3.34694 3.02374
+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --bwd
(Batch, Heads, SeqLen, Dhead) sdpa-latency flash_v2-latency flex_attention-latency
------------------------------- -------------- ------------------ ------------------------
(32, 16, 512, 128) 1.36323 1.30051 0.94192
(16, 16, 1024, 128) 1.89187 1.80678 1.40486
(8, 16, 2048, 128) 2.93325 2.83165 2.33082
(4, 16, 4096, 128) 5.05456 4.91002 4.19229
(2, 16, 8192, 128) 9.34131 9.10381 7.87952
(1, 16, 16384, 128) 17.9824 17.5658 15.4029
+ python ./run_benchmark.py triton --op flash_attention --d-head 128 --only sdpa,flash_v2,flex_attention --bwd --causal
(Batch, Heads, SeqLen, Dhead) sdpa-latency flash_v2-latency flex_attention-latency
------------------------------- -------------- ------------------ ------------------------
(32, 16, 512, 128) 1.14022 1.07926 0.838688
(16, 16, 1024, 128) 1.41398 1.33376 1.06554
(8, 16, 2048, 128) 1.97437 1.87725 1.53197
(4, 16, 4096, 128) 3.08925 2.95606 2.48413
(2, 16, 8192, 128) 5.28237 5.08138 4.37834
(1, 16, 16384, 128) 9.27056 8.94672 8.24531
```
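For context, here is a minimal sketch (not the benchmark harness's actual code) of how two of the backends compared above map onto PyTorch APIs, assuming PyTorch >= 2.5 with FlexAttention available and a CUDA device. `flex_attention` and `create_block_mask` are the real `torch.nn.attention.flex_attention` entry points; the shape and dtype are picked to match one row of the tables and are otherwise arbitrary, and the `flash_v2` column (presumably the FlashAttention-2 library kernel) is left out of the sketch.
```
# Minimal sketch of the sdpa and flex_attention backends; assumes PyTorch >= 2.5 and CUDA.
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 4, 16, 4096, 128  # one (Batch, Heads, SeqLen, Dhead) config from the tables
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

# sdpa baseline (causal variant, as with --causal above)
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# FlexAttention with an equivalent causal block mask; compiling flex_attention is what
# lowers the mask into a fused Triton kernel rather than running it eagerly.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
compiled_flex = torch.compile(flex_attention)
out_flex = compiled_flex(q, k, v, block_mask=block_mask)

# --bwd adds a backward pass over the output
out_flex.sum().backward()
```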
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2443
Reviewed By: xuzhao9
Differential Revision: D61924897
Pulled By: bertmaher
fbshipit-source-id: 40e18e6bee2b91e4a9826f5056a431950ee3495d