onnxruntime
a6887f17 - Refactor schema extraction and output unflattening (#16894)

Commit
2 years ago
Refactor schema extraction and output unflattening (#16894) ### Motivation and Context When we handle PyTorch models' inputs in different places (ORTModule or others), it's common for us to flatten a structured data into a 1-D tensor list (required by lib for example torch.onnx.export, torch.autograd.Function.forward or ORT inference session), then do subsequent work, then unflatten back to original hierarchy as returned values. DeepStage3 hooks support work also need such a lib to do similar things, so I was proposing to extract this pair of APIs in training/utils/, which can be more used more generally. Also a comprehensive set of test data are used for testing unflatten/flatten in unit tests. Let me know if you have any other suggestions. ### Refactor schema extraction and output unflattening Move `_extract_schema` and `unflatten_user_output` in `orttraining/orttraining/python/training/ortmodule/_io.py` . to `extract_data_and_schema` and `unflatten_data_using_schema` in `orttraining/orttraining/python/training/utils/torch_io_helper.py` as shared libs, which can be used later by other features (deepspeed stage 3 hook rewrite). While there are still a few duplicated logic handling flatten with different task by recursively loop the data struct, will change them step by step in case of heavy review efforts.
Author
Parents
Loading