Fix dimension order for FlashAttention-3 (#2391)
Summary:
We were passing tensors as (B, H, S, D), which is the layout the Triton kernel expects, but FA3 expects (B, S, H, D); see the sketch below.
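A minimal sketch of the layout fix, assuming FA3 is invoked through `flash_attn_func` from the hopper `flash_attn_interface` module; the shapes and tensor names are illustrative:
```
import torch

B, H, S, D = 8, 16, 2048, 128
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# The Triton tutorial kernel consumes the (B, H, S, D) layout directly,
# but FA3 wants (B, S, H, D), so swap the head and sequence dims first.
q_fa3 = q.transpose(1, 2).contiguous()  # (B, S, H, D)
k_fa3 = k.transpose(1, 2).contiguous()
v_fa3 = v.transpose(1, 2).contiguous()
# out = flash_attn_func(q_fa3, k_fa3, v_fa3)
```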
Also, the FA3 paper sizes its inputs so that B*S = 16k and H*Dhead = 2048; let's replicate that here (sizing sketch below). Finally, print the full tensor sizes in the output table, not just the sequence length.
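A minimal sketch of that sizing, using a hypothetical helper that doubles the sequence length while holding B*S = 16384 and H*Dhead = 2048 fixed:
```
def fa3_paper_shapes(d_head=128, tokens=16 * 1024, hidden=2048):
    # Hold B*S and H*Dhead constant while doubling the sequence length;
    # this reproduces the shape sweep in the table below.
    n_heads = hidden // d_head  # 16 heads for d_head=128
    seq_len = 512
    while seq_len <= tokens:
        yield (tokens // seq_len, n_heads, seq_len, d_head)
        seq_len *= 2

print(list(fa3_paper_shapes()))
# [(32, 16, 512, 128), (16, 16, 1024, 128), ..., (1, 16, 16384, 128)]
```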
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2391
Test Plan:
```
python run_benchmark.py triton --op flash_attention --only triton_tutorial_flash_v2,flash_v3 --metrics tflops --d-head 128
(Batch, Heads, SeqLen, Dhead)  triton_tutorial_flash_v2-tflops  flash_v3-tflops
-----------------------------  -------------------------------  ---------------
(32, 16, 512, 128)                                     328.882           430.32
(16, 16, 1024, 128)                                    389.723          517.099
(8, 16, 2048, 128)                                     433.721          565.544
(4, 16, 4096, 128)                                     458.359          590.233
(2, 16, 8192, 128)                                     470.732          614.234
(1, 16, 16384, 128)                                    476.753           620.73
```
Reviewed By: xuzhao9, manman-ren
Differential Revision: D60117034
Pulled By: bertmaher
fbshipit-source-id: bcdbae986c549948936c5ef429f7cb3298ee9ec1