onnxruntime
6b7bce5e - Model post process for zero stage3 training (#17187)

Commit
2 years ago
### Model post process for zero stage3 training

This is the last change needed for single-GPU and multi-GPU runs to pass.

Design details: https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

`PyTorch` runs with ZeROOffloadSubscriber:

```
model = prepare_model(...)

# Configure ZeRO stage 3 so the PyTorch run goes through the ORT-compatible offload path.
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
configure_ort_compatible_zero_stage3()
```

`ORTModule` runs with ZeROOffloadSubscriber:

```
import os

# Set the flag before importing and constructing ORTModule.
os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'

from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(self.model)
```

Debugging convergence issues becomes much easier when both ORT and PyTorch can run the same offload path.
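To compare the two runs from one training script, the two snippets above can be folded into a single switch. The following is only a minimal sketch: `build_offload_model` and its `import_module` parameter are hypothetical names that are not part of onnxruntime. The sketch assumes the environment variable should be set before `ORTModule` is imported, mirroring the order used above; the commit message does not say when the flag is actually read.

```python
import importlib
import os


def build_offload_model(model, use_ortmodule, import_module=importlib.import_module):
    """Hypothetical helper: route `model` through one of the two offload paths.

    `import_module` is injectable only so the function can be exercised
    without onnxruntime-training installed; it is an assumption of this
    sketch, not an onnxruntime API.
    """
    if use_ortmodule:
        # Mirror the order in the commit message: set the flag first,
        # then import and construct ORTModule.
        os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = "1"
        ortmodule = import_module("onnxruntime.training.ortmodule")
        return ortmodule.ORTModule(model)

    # Plain PyTorch path: configure ZeRO stage 3 to use the
    # ORT-compatible offload hooks and keep the model unwrapped.
    hooks = import_module("onnxruntime.training.utils.hooks")
    hooks.configure_ort_compatible_zero_stage3()
    return model
```

With onnxruntime-training installed, `build_offload_model(model, use_ortmodule=False)` gives the PyTorch baseline and `build_offload_model(model, use_ortmodule=True)` gives the ORTModule run, so the two loss curves can be compared on the same offload path.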