Stable Diffusion 3.x and Flux Optimization (#22986)
### Description
It has dependency on the following PRs:
- https://github.com/microsoft/onnxruntime/pull/23297
Optimize the ONNX pipeline for Stable Diffusion 3.x and Flux 1.0 models
(fp32 or fp16).
- [x] Update optimize_pipeline script
- [x] Update benchmkark script
- [x] Update document about Stable Diffusion 3.x and Flux 1.0 models
- [x] Add graph optimizations for MMDit model
- [x] FastGelu fusion
- [x] RMSNorm fusion
- [x] MultiHeadAttention fusion
- [x] Add graph optimizations for Flux transformer models
- [x] MultiHeadAttention fusion
- [x] Update graph optimizations for t5
- [x] Add tests
Optimize the ONNX pipeline for Stable Diffusion 3.x and Flux 1.0 models:
```
python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16
Optimize flux1_schnell_onnx/fp32/transformer/model.onnx ...
Fused LayerNormalization: 115
Fused SimplifiedLayerNormalization: 152
Fused FastGelu: 76
Fused MultiHeadAttention: 57
```
### H100 Benchmark Results
* GPU: NVIDIA H100 80GB HBM3
* Image Size: 1024x1024
* Batch Size: 1
Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB)
-- | -- | -- | -- | -- | --
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 8.198 | 37,603
Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 10.762 | 41,469
Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 10.891 | 43,545
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 12.339 | 36,651
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 0.775 | 37,857
Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 0.931 | 41,433
Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 0.939 | 43,809
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 1.120 | 36,629
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 7.466 | 32,217
SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 10.275 | 36,609
SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 10.283 | 36,729
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 11.615 | 31,517
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 3.240 | 21,143
SD 3.5 Medium | 50 | FP16+BF16 | Optimum (ORT) | 4.799 | 25,097
SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 4.838 | 25,109
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 5.582 | 20,489
### A100 Benchmark Results
* GPU: A100-SXM4-80GB
* Image Size: 1024x1024
* Batch Size: 1
Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB)
-- | -- | -- | -- | -- | --
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 17.593 | 37,723
Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 21.918 | 41,348
Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 22.060 | 44,860
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 24.267 | 36,847
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 1.627 | 37,881
Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 1.884 | 41,537
Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 1.902 | 44,858
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 2.162 | 36,831
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 15.881 | 32,307
SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 19.837 | 36,451
SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 19.964 | 36,461
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 22.477 | 31,513
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 6.476 | 21,341
SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 8.775 | 25,183
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 10.057 | 20,433
### Future Works
* Triton kernel for matrix multiplication and auto tuning.
* FP8/Int8 quantization
### Motivation and Context
SD 3.5 Architecture:
https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/resolve/main/mmdit-x.png