onnxruntime
c55c6255 - Eliminate safe nodes that are followed by a shape node. (#16065)

Commit
2 years ago
Eliminate safe nodes that are followed by a shape node. (#16065) ### Description Eliminate Cast operator if Shape is the next one. ### Motivation and Context #### Cast When working with onnx opset 15 and above, the shape operator now accepts all types of variables. This change is documented in the [onnx Changelog](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15). As a result, casting variables right before the shape operation becomes unnecessary. Removing these unnecessary casts will improve the graph and potentially provide better performance gains. ## Results On : torchrun examples/onnxruntime/training/language-modeling/run_clm.py --model_name_or_path gpt2 --do_train --overwrite_output_dir --output_dir ./outputs/ --seed 1337 --fp16 True --per_device_train_batch_size 4 --num_train_epochs 1 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --learning_rate 2e-5 --report_to none --optim adamw_ort_fused without changes: ***** train metrics ***** epoch = 1.0 train_loss = 3.2981 train_runtime = 0:02:13.29 train_samples = 2318 train_samples_per_second = 17.39 train_steps_per_second = 4.351 With my changes: ***** train metrics ***** epoch = 1.0 train_loss = 3.2981 train_runtime = 0:02:08.98 train_samples = 2318 train_samples_per_second = 17.971 train_steps_per_second = 4.497 We see around 3% gain. --------- Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Author
Parents
Loading