093495d3 - [fx] prevent implicit submodule inlining when submodule is a GraphModule (#62436)

[fx] prevent implicit submodule inlining when submodule is a GraphModule (#62436)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62436

## Problem

Given two modules and a tracer that indiscriminately marks every module as a leaf:

```
class InnerModule(torch.nn.Module):
    def forward(self, t):
        return t + t

class MyModule(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, t):
        x = self.inner(t)
        y = self.inner(t)
        return x + y

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, module, name):
        return True
```

One might generally expect the following behavior (note the call_module nodes):

```
print(">> Outer GraphModule (with inner module as nn.Module):")
inner = InnerModule()
m = MyModule(inner)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())

>> Outer GraphModule (with inner module as nn.Module):
opcode         name     target                   args              kwargs
-------------  -------  -----------------------  ----------------  --------
placeholder    t        t                        ()                {}
call_module    inner    inner                    (t,)              {}
call_module    inner_1  inner                    (t,)              {}
call_function  add      <built-in function add>  (inner, inner_1)  {}
output         output   output                   (add,)            {}
None
```

However, when the inner module is first symbolically traced, the symbolic trace of the outer module ignores `is_leaf_module` entirely and traces through the whole module (note the call_function nodes):

```
print(">> Inner module as GraphModule:")
inner = InnerModule()
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
print(inner_gm.graph.print_tabular())

print(">> Outer GraphModule (with inner module as GraphModule):")
m = MyModule(inner_gm)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())

>> Inner module as GraphModule:
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    t       t                        ()      {}
call_function  add     <built-in function add>  (t, t)  {}
output         output  output                   (add,)  {}
None

>> Outer GraphModule (with inner module as GraphModule):
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    t       t                        ()            {}
call_function  add     <built-in function add>  (t, t)        {}
call_function  add_1   <built-in function add>  (t, t)        {}
call_function  add_2   <built-in function add>  (add, add_1)  {}
output         output  output                   (add_2,)      {}
None
```

This is surprising behavior and at first glance violates the tracer's intent. As I understand it, `torch.fx.symbolic_trace.Tracer.trace` intends to patch `torch.nn.Module.__call__` with a `module_call_wrapper()` that records a `call_module` node if the module is a leaf, and otherwise executes `torch.fx.symbolic_trace._orig_module_call` (`= torch.nn.Module.__call__`, set at module loading time).

**Every submodule should be a leaf, but no `call_module` nodes are created when that submodule is a `GraphModule`. Why?**

Upon further inspection, I found:

- The constructor for `GraphModule` includes a path to `GraphModule.recompile()` via the setter for its `fx.Graph`:

```
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
  File "/torch/fx/graph_module.py", line 252, in __init__
    self.graph = graph
  File "/torch/nn/modules/module.py", line 1183, in __setattr__
    object.__setattr__(self, name, value)
  File "/torch/fx/graph_module.py", line 277, in graph
    self.recompile()
```

- `recompile()` wraps the `__call__` method by holding a reference to the `__call__` method at the time of recompilation:

```
cls = type(self)
cls_call = cls.__call__
...
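# [annotation, not in the original source] cls_call is bound here, at
# recompile() time. On the first recompile, GraphModuleImpl does not yet
# define __call__ itself, so this lookup resolves through the MRO to the
# current torch.nn.Module.__call__ (i.e. _orig_module_call). When
# Tracer.trace() later patches torch.nn.Module.__call__, wrapped_call below
# still invokes the old function captured in cls_call, bypassing the patch.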
def wrapped_call(self, *args, **kwargs):
    try:
        return cls_call(self, *args, **kwargs)
    except Exception as e:
        ...

cls.__call__ = wrapped_call
```

- Recompilation of the inner GraphModule happens on initialization, before the outer module is created or traced. Adding some old-fashioned print debug statements gives:

```
Inner Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
recompile: cls.__call__ now wraps _orig_module_call, <function Module._call_impl at 0x7faaebfee8b0>

Outer Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
tracing: patching method <class 'torch.nn.modules.module.Module'>.__call__ <function Module._call_impl at 0x7faaebfee8b0> with <function Module._call_impl at 0x7fa9d42bce50>

outer module MRO before tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

outer module MRO during tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

inner module MRO before tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

inner module MRO during tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
```

- The outer module is patched correctly, but the first entry in the inner module's MRO is the `wrapped_call` from `recompile`, which still invokes `<function Module._call_impl at 0x7faaebfee8b0>` directly. Therefore, no call_module nodes are created.

## In Practice

In practice, this behavior affects the ability of `torch.package` to package `GraphModule`s whose submodules are themselves `GraphModule`s. In our case, the `GraphModule` submodules are not passed through a constructor, but created separately and installed on the root `GraphModule` via `setattr`. This means that, prior to packaging, there appear to be no issues with the module, since the root's graph was created before any call_module targets were replaced with `GraphModule`s. When such a model is unpackaged with `torch.package`, `torch.fx.graph_module._deserialize_graph_module` uses an inline `KeepModules` tracer that marks all submodules as leaves; the unpackaged module is implicitly and surprisingly inlined in the process.

## Potential Solution

This behavior was previously not understood by us, so the current workaround is a gnarly process of wrapping every such submodule in an `nn.Module` with a manually installed forward method (sketched below).
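For illustration, a minimal sketch of that wrapper workaround, reusing the `InnerModule`/`MyModule`/`MyTracer` example above; the `GraphModuleWrapper` name and exact shape are hypothetical, not taken from our codebase:

```
import torch
import torch.fx

class GraphModuleWrapper(torch.nn.Module):
    """Plain nn.Module shim around a GraphModule submodule."""

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__()
        self.gm = gm

    def forward(self, *args, **kwargs):
        # At runtime this simply delegates to the wrapped GraphModule.
        return self.gm(*args, **kwargs)

# Because the wrapper is an ordinary nn.Module, its __call__ resolves to the
# patched torch.nn.Module.__call__ during symbolic tracing, so the tracer
# treats it as a leaf and records a call_module node instead of inlining
# inner_gm's graph.
inner = InnerModule()
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
m = MyModule(GraphModuleWrapper(inner_gm))
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())  # shows call_module nodes for the wrapper
```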
Changing `wrapped_call` to `return super(type(self), self).__call__(*args, **kwargs)` whenever `__call__` is inherited at least appears to solve the issue. Does this seem like an acceptable approach? (A rough sketch of this change is included at the end of this summary.)

## Other Thoughts

- Repeated calls to `recompile` create nested `wrapped_call`s, all for the purpose of error handling. This is probably unnecessary ¯\\_(ツ)\_/¯
- If a root module with an overridden `__call__` method is symbolically traced, the override is ignored.

Test Plan:
```
buck test:
✓ ListingSuccess: caffe2/test:fx - main (12.570)
✓ Pass: caffe2/test:fx - test_tracing_graphmodules_as_leaf_submodules (test_fx.TestFX) (11.982)
```

Reviewed By: ansley

Differential Revision: D29997935

fbshipit-source-id: 1988fbb025b14188da26a3e73e94fb789c3c1f74
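For reference, a rough sketch of the `wrapped_call` change described under Potential Solution. This is not the committed diff, just an illustration of the idea, with the inherited-`__call__` check captured before `wrapped_call` is installed:

```
cls = type(self)
# Record whether __call__ is inherited (i.e. not defined on cls itself)
# before wrapped_call is installed below.
call_is_inherited = "__call__" not in vars(cls)
cls_call = cls.__call__

def wrapped_call(self, *args, **kwargs):
    try:
        if call_is_inherited:
            # Resolve __call__ dynamically through the parent class, so a
            # patched torch.nn.Module.__call__ installed by Tracer.trace()
            # is honored instead of the stale cls_call reference.
            return super(cls, self).__call__(*args, **kwargs)
        return cls_call(self, *args, **kwargs)
    except Exception:
        ...  # error reporting elided, as in the snippet quoted earlier
        raise

cls.__call__ = wrapped_call
```

At call time `type(self)` is `cls`, so `super(cls, self)` starts the `__call__` lookup above `GraphModuleImpl` in the MRO and picks up whatever `torch.nn.Module.__call__` is currently installed, patched or not.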