DeepSpeed
6f1a1c04 - Restore real inputs for recompilation (#7356)

Commit
313 days ago
Restore real inputs for recompilation (#7356)

This PR keeps some of the real inputs given to the custom backend for DeepCompile.

DeepCompile expects the custom backend at the TorchFX graph level to be called whenever recompilation happens. In some cases, however, only the Aten-level backend is called. Because the Aten-level backend uses the real inputs saved by the TorchFX-level backend, those inputs must remain available for recompilation.

Currently we discard the real inputs after the Aten-level backend uses them, because they are often too large to keep in GPU memory. This causes an error when recompilation calls only Aten-level backends, since the TorchFX-level backend never gets a chance to record new real inputs.

With this PR, only tensor metadata and non-tensor data are always kept on the CPU, and the tensors are materialized when needed (i.e. when recompilation happens and only Aten-level backends are called without real inputs). Because the tensors are materialized with dummy data, this solution might still not work in every case, but it improves coverage.

The new module `InputStorage` keeps the tensor metadata and non-tensor data for this purpose and materializes the tensors.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
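
The idea of keeping only metadata and rebuilding tensors on demand can be illustrated with a small, library-free sketch. Everything here is hypothetical and is not DeepSpeed's actual `InputStorage`: `ToyTensor` is a stand-in for a real tensor so the example runs without PyTorch, and the names `ToyInputStorage`, `TensorMeta`, `put`, `get`, `has_data`, and `make_dummy` are assumptions made for illustration. In a PyTorch setting, the dummy factory would presumably be something like `torch.zeros` called with the recorded shape and dtype.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass(frozen=True)
class ToyTensor:
    """Stand-in for a real tensor (assumption: lets the sketch run without PyTorch)."""
    shape: Tuple[int, ...]
    dtype: str
    data: Tuple[float, ...] = ()


@dataclass(frozen=True)
class TensorMeta:
    """Cheap description of a tensor: shape and dtype only, no payload."""
    shape: Tuple[int, ...]
    dtype: str


def make_dummy(meta: TensorMeta) -> ToyTensor:
    """Materialize a tensor of the recorded shape, filled with zeros (dummy data)."""
    numel = 1
    for dim in meta.shape:
        numel *= dim
    return ToyTensor(meta.shape, meta.dtype, (0.0,) * numel)


class ToyInputStorage:
    """Hypothetical sketch of the approach; not the real InputStorage module."""

    def __init__(self, materialize: Callable[[TensorMeta], Any] = make_dummy):
        self._records: Optional[List[Any]] = None
        self._materialize = materialize

    def put(self, real_inputs) -> None:
        # Keep only metadata for tensors; keep non-tensor inputs as they are.
        self._records = [
            TensorMeta(tuple(x.shape), x.dtype) if isinstance(x, ToyTensor) else x
            for x in real_inputs
        ]

    def has_data(self) -> bool:
        return self._records is not None

    def get(self) -> list:
        # Rebuild inputs: dummy tensors from metadata, non-tensor values unchanged.
        if self._records is None:
            raise RuntimeError("no inputs have been recorded")
        return [
            self._materialize(r) if isinstance(r, TensorMeta) else r
            for r in self._records
        ]


# Record inputs once, then rebuild them later without the original payload.
storage = ToyInputStorage()
storage.put([ToyTensor((2, 3), "float32", (1.0, 2.0, 3.0, 4.0, 5.0, 6.0)), 42, "mode"])
restored = storage.get()
```

The restored tensor has the original shape and dtype but not the original values. That matches the caveat in the commit message: rebuilding from dummy data can still fail, presumably whenever compilation depends on the actual tensor contents.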