[cpu] add sdpa choice and UT (#105131)
Feature RFC: https://github.com/pytorch/rfcs/pull/56.
Add an SDPA selection function for CPU that automatically chooses among the available SDPA implementations. There are currently two CPU implementations to choose from: the unfused SDPA and flash attention. In general, flash attention has a higher priority than the unfused SDPA. When flash attention is not applicable, for example when it is manually disabled or the inputs are not 4-dimensional, the unfused SDPA is chosen.
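A minimal sketch of how the fallback behaves from the Python side, assuming the CPU choice honors the same `torch.backends.cuda.sdp_kernel` flags used to gate SDPA backends on CUDA (the flag name is historical and not CUDA-specific):
```python
import torch
import torch.nn.functional as F

# 4-D inputs in [batch, num_heads, seq_len, head_size] layout are eligible for
# the CPU flash attention kernel.
q, k, v = (torch.randn(1, 25, 1024, 64) for _ in range(3))

# Default: the selection function prefers flash attention on CPU.
out_flash = F.scaled_dot_product_attention(q, k, v)

# Manually disabling flash attention makes the choice fall back to the
# unfused (math) SDPA. The toggle lives under torch.backends.cuda for
# historical reasons; it is assumed here to gate the CPU choice as well.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True,
                                     enable_mem_efficient=False):
    out_unfused = F.scaled_dot_product_attention(q, k, v)

# The two paths should agree up to floating-point accumulation differences.
print((out_flash - out_unfused).abs().max())
```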
## Performance of the stack
### NanoGPT's SDPA kernel
Using the benchmark [repo](https://github.com/mingfeima/bench_sdpa/blob/main/README.md), run on a single socket.
Shape: Batch size 1, Sequence length 1024, Head number 25, Head size 64.
Machine: SPR (Sapphire Rapids).
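A minimal sketch of the call measured in the table below, assuming a `[batch, heads, seq_len, head_size]` layout; the benchmark repo above is the actual harness, this is only illustrative:
```python
import torch
import torch.nn.functional as F

# Benchmark shape: batch 1, 25 heads, sequence length 1024, head size 64.
batch, heads, seq_len, head_size = 1, 25, 1024, 64

for dtype in (torch.float32, torch.bfloat16):
    q, k, v = (torch.randn(batch, heads, seq_len, head_size, dtype=dtype)
               for _ in range(3))
    for is_causal in (False, True):
        # Inference mode; the training rows additionally require
        # requires_grad_() on q/k/v and a backward pass.
        with torch.no_grad():
            out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```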
| Dtype | Causal | Mode | SDPA | Time (ms per iter) | Speedup |
| -------- | -------- | ------- | ------- | ------- | ------- |
| float32 | FALSE | Inference | Unfused | 3.081 | |
| | | | Flash attention | 1.665 | **1.85045** |
| float32 | TRUE | Inference | Unfused | 3.463 | |
| | | | Flash attention | 1.662 | **2.083634**|
| bfloat16 | FALSE | Inference | Unfused | 1.203 | |
| | | | Flash attention | 1.154 | **1.042461**|
| bfloat16 | TRUE | Inference | Unfused | 1.543 | |
| | | | Flash attention | 1.154 | **1.337088**|
| float32 | FALSE | Training | Unfused | 54.938 | |
| | | | Flash attention | 23.029 | **2.385601**|
| float32 | TRUE | Training | Unfused | 58.266 | |
| | | | Flash attention | 17.835 | **3.266947**|
| bfloat16 | FALSE | Training | Unfused | 18.924 | |
| | | | Flash attention | 18.886 | **1.002012**|
| bfloat16 | TRUE | Training | Unfused | 21.08 | |
| | | | Flash attention | 14.172 | **1.48744** |
### Stable Diffusion
Following the model's [BKM](https://github.com/intel-innersource/frameworks.ai.models.intel-models/blob/develop/quickstart/diffusion/pytorch/stable_diffusion/inference/cpu/README.md) (best-known method).
Mode: Inference; Machine: SPR.
| Dtype | SDPA | Throughput (fps) | SDPA Speedup | Total Time (ms) | Total Speedup |
| -------- | -------- | ------- | ------- | ------- | ------- |
| float32 | Unfused | 1.63 | | 1139 | |
| | Flash attention | 1.983 | 1.216564 | 547.488 | **2.080411**|
| bfloat16 | Flash attention in IPEX | 4.784 | | 429.051 | |
| | Flash attention | 4.857 | 1.015259 | 408.823 | **1.049479**|
### LLM models of Torchbench
Dtype: float32; Mode: Inference, single socket; Machine: CPX (Cooper Lake).
Model name | SDPA | Inductor_new | Inductor_old | Inductor Ratio (old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.048629309 | 0.05591545 | **1.14983024**
hf_Bert | Unfused -> Flash attention | 0.053156243 | 0.060732115 | **1.142520841**
hf_Bert_large | Unfused -> Flash attention | 0.141089502 | 0.155190077 | **1.099940636**
llama | Unfused -> Flash attention | 0.033250106 | 0.033720745 | **1.01415451**
Dtype: bfloat16; Mode: Inference, single socket; Machine: SPR.
Model name | SDPA | Inductor_new | Inductor_old | Inductor Ratio (old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.020681298 | 0.020718282 | **1.001788324**
hf_Bert | Unfused -> Flash attention | 0.019932816 | 0.019935424 | **1.000130842**
hf_Bert_large | Unfused -> Flash attention | 0.047949174 | 0.048312502 | **1.007577355**
llama | Unfused -> Flash attention | 0.018528057 | 0.01861126 | **1.0044907**
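For context on how numbers like these are collected, a minimal sketch of running an SDPA-based module under Inductor on CPU; the module below is a toy stand-in and not the Torchbench harness:
```python
import torch
import torch.nn.functional as F

class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

# torch.compile uses Inductor by default; on CPU the SDPA choice function
# picks flash attention or the unfused path for the lowered SDPA call.
model = torch.compile(ToyAttention())

q, k, v = (torch.randn(1, 12, 128, 64) for _ in range(3))
with torch.no_grad():
    out = model(q, k, v)
```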
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105131
Approved by: https://github.com/drisspg
ghstack dependencies: #104583, #104584, #103826, #104693, #104863, #107128