pytorch
547bef11 - tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)

Commit
2 years ago
tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644) High level approach: 1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK) 2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR). 2a. I did a bunch of feature augmentation to capture interactions between features. The heuristic I ended up with is: ``` use_flash = seq_len / (num_heads * batch_size) > 6 ``` TL;DR: On my dataset, where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy. My new heuristic achieves 94% accuracy. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644 Approved by: https://github.com/ngimel, https://github.com/drisspg
Author
Committer
Parents
Loading