pytorch
5cfca552 - [JIT] clear GraphFunction.optimized_graphs_ after freezing a module (#68316)

Commit
3 years ago
[JIT] clear GraphFunction.optimized_graphs_ after freezing a module (#68316) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68316 Consider the following: ``` class Mod(nn.Module): def __init__(self, val): super().__init__() self.param = nn.Parameter(val) def forward(self, x): # this method will change during freezing return x + self.param torch.jit.export def make_prediction(self, x): y = x + x return self.forward(y) param = torch.rand([2, 2]) unscripted_mod = Mod(param) mod = torch.jit.script(unscripted_mod) mod.eval() mod = torch.jit.freeze(mod, preserved_attrs=["make_prediction"])` ``` During freezing the following will occur: 1. do some pre-freezing, including inlining; in particular, forward will be inlined into make_prediction. During inlining, forward.optimized_graph() is called, and the result is cached 2. freeze some methods. While freezing forward, the graph associated with the function will get updated. The cached optimized_graphs_ are not updated. Previously, a call to `mod.forward(x)` would return an exectutor that would run on the old cached optimized_graph(). This would mean that the freezing optimizations would not apply, and potentially that the execution would fail because of parameters removed from the module. This change clears the optimized_graphs_ cache after running freezing to prevent executing an old version of the graph. Test Plan: Imported from OSS Reviewed By: eellison Differential Revision: D32410862 Pulled By: davidberard98 fbshipit-source-id: dd8bfe86ec2898b7c72813ab32c08f25c38e4cea
Author
Parents
Loading