pytorch
87f08d71 - torch.save/load torch.compiled models (#97565)

Commit
1 year ago
torch.save/load torch.compiled models (#97565) Opening this so I can discuss with @albanD I built a proof of concept of an in place API for an nn.Module that allows us to save and load a torch.compiled model with no issues https://github.com/msaroufim/mlsys-experiments/blob/main/save-compiled-model.py So users can run` model.compile()` and then run `torch.save(model, "model.pt")` and `torch.load(model, "model.pt)` with no issues unlike the rather strange current suggestion we give to users which is `opt_mod = torch.compile(mod); torch.save(mod, "model.pt")` Right now I'm trying to extend this to work for nn.modules more generally TODO: Failing tests * [x] torch.jit.load -> issue was because of aliasing `__call__` to `_call_impl`, _call_impl used to be skipped when now it lo longer is so expanded the skip check. I added an explicit `torch.jit.load()` test now which @davidberard98 suggested * [x] functorch seems to be a flake - ran locally and it worked `pytest functorch/test_eager_transforms.py` * [x] a test infra flake - `test_testing.py::TestImports::test_no_mutate_global_logging_on_import_path_functorch` * [x] It seems like I broke inlining in dynamo though `python -m pytest test/dynamo/test_dynamic_shapes.py -k test_issue175` chatting with Voz about it but still not entirely sure how to fix - found a workaround after chatting with @yanboliang * [x] `pytest test/dynamo/test_modules.py` and `test/dynamo/test_dynamic_shapes` `test/dynamo/test_misc.py` seem to be failing in CI but trying it out locally they all pass tests passed with 0 failures * [x] `pytest test/profiler/test_profiler_tree.py ` these tests have ProfilerTrees explicitly printed and will now break if __call__ is not in tree - ran with `EXPECT_ACCEPT=1` * [x] `pytest test/test_torch.py::TestTorch::test_typed_storage_deprecation_warning` a flake, ran this locally and it works fine * [x] I reverted my changes to `_dynamo/nn_module.py` since it looks like @wconstab is now directly handling `_call_impl` there but this is triggering an infinite inlining which is crashing * [x] Tried out to instead override `__call__`, python doesnt like this though https://github.com/pytorch/pytorch/pull/97565#issuecomment-1524570439 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97565 Approved by: https://github.com/aaronenyeshi, https://github.com/albanD
Author
Committer
Parents
Loading