onnxruntime
58af36b4 - Fuse ScaledSum and its backward BatchScale (#16517)

### Fuse ScaledSum and its backward BatchScale (#16517)

For DeBERTa models, there is a recurring pattern `a / scalar_0 + b / scalar_1 + c / scalar_2`. We can fuse this into a ScaledSum operator, which takes 2 (or 3) inputs and 2 (or 3) scalar attributes and generates one output. For the backward pass, the gradients of a, b, and c are computed with BatchScale.

### Benchmark on 8x32G V100

```bash
torchrun --nproc_per_node=8 examples/onnxruntime/training/language-modeling/run_mlm.py \
  --model_name_or_path microsoft/deberta-v3-large \
  --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
  --num_train_epochs 10 --do_train --overwrite_output_dir --output_dir ./outputs/ \
  --seed 1137 --fp16 --report_to none --optim adamw_ort_fused \
  --max_steps 400 --logging_steps 1 --use_module_with_loss \
  --deepspeed aml_ds_config_zero_1.json --per_device_train_batch_size 10
```

#### Main Branch

```
Total overhead: 127954ms where export takes 116489ms.
epoch                      = 14.29
train_loss                 = 4.9803
train_runtime              = 0:10:27.29
train_samples              = 2223
train_samples_per_second   = 51.013
train_steps_per_second     = 0.638
throughput per GPU = 14.29 * 2223 / (627.29 - 127.954) / 8 (gpu) = 7.952 samples/second
```

#### This PR

```
Total overhead: 128761ms where export takes 118510ms.
***** train metrics *****
epoch                      = 14.29
train_loss                 = 4.6144
train_runtime              = 0:10:04.31
train_samples              = 2223
train_samples_per_second   = 52.953
train_steps_per_second     = 0.662
throughput per GPU = 14.29 * 2223 / (604.31 - 128.761) / 8 = 8.350 samples/second
```

That is roughly a 5% throughput gain per GPU (8.350 / 7.952 ≈ 1.05).
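The semantics of the fusion can be sketched in NumPy. This is an illustrative model only, not the actual ONNX Runtime kernels: `scaled_sum` and `batch_scale` below are hypothetical reference implementations showing that the fused forward equals the original `a / scalar_0 + b / scalar_1 + c / scalar_2` pattern, and that each input's gradient is just the upstream gradient scaled by the reciprocal of its scalar.

```python
import numpy as np

def scaled_sum(inputs, scales):
    """Reference sketch of the fused ScaledSum forward:
    out = sum_i inputs[i] / scales[i], for 2 or 3 inputs."""
    assert len(inputs) == len(scales) and len(inputs) in (2, 3)
    out = np.zeros_like(inputs[0])
    for x, s in zip(inputs, scales):
        out = out + x / s
    return out

def batch_scale(grad_output, scales):
    """Reference sketch of the BatchScale backward: since
    d(out)/d(inputs[i]) = 1/scales[i], each input's gradient is
    the upstream gradient divided by its scalar."""
    return [grad_output / s for s in scales]

# Check the fused forward matches the unfused a/s0 + b/s1 + c/s2.
rng = np.random.default_rng(0)
a, b, c = (rng.standard_normal((4, 8)) for _ in range(3))
s = [2.0, 3.0, 5.0]

fused = scaled_sum([a, b, c], s)
unfused = a / s[0] + b / s[1] + c / s[2]
assert np.allclose(fused, unfused)

# Backward: one BatchScale produces all three input gradients.
g = rng.standard_normal((4, 8))
ga, gb, gc = batch_scale(g, s)
assert np.allclose(ga, g / 2.0)
```

The point of the fusion is that the three divisions and two additions (and, in backward, the three per-input scalings) each become a single kernel launch instead of five separate elementwise ops.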