[fx] prevent implicit submodule inlining when submodule is a GraphModule (#62436)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62436
## Problem
Given two modules and a tracer that indiscriminately marks all modules as a leaf:
```
class InnerModule(torch.nn.Module):
def forward(self, t):
return t + t
class MyModule(torch.nn.Module):
def __init__(self, inner):
super().__init__()
self.inner = inner
def forward(self, t):
x = self.inner(t)
y = self.inner(t)
return x + y
class MyTracer(torch.fx.Tracer):
def is_leaf_module(self, module, name):
return True
```
One might generally expect the following behavior (note call_module nodes):
```
print(">> Outer GraphModule (with inner module as nn.Module):")
inner = InnerModule()
m = MyModule(inner)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())
>> Outer GraphModule (with inner module as nn.Module):
opcode name target args kwargs
------------- ------- ----------------------- ---------------- --------
placeholder t t () {}
call_module inner inner (t,) {}
call_module inner_1 inner (t,) {}
call_function add <built-in function add> (inner, inner_1) {}
output output output (add,) {}
None
```
However, when the inner module is first symbolically traced, the symbolic trace of the outer module ignores `is_leaf_node` entirely, and traces through the whole module (note call_function nodes).
```
print(">> Inner module as GraphModule:")
inner = InnerModule()
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
print(inner_gm.graph.print_tabular())
print(">> Outer GraphModule (with inner module as GraphModule):")
m = MyModule(inner_gm)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())
>> Inner module as GraphModule:
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder t t () {}
call_function add <built-in function add> (t, t) {}
output output output (add,) {}
None
>> Outer GraphModule (with inner module as GraphModule):
opcode name target args kwargs
------------- ------ ----------------------- ------------ --------
placeholder t t () {}
call_function add <built-in function add> (t, t) {}
call_function add_1 <built-in function add> (t, t) {}
call_function add_2 <built-in function add> (add, add_1) {}
output output output (add_2,) {}
None
```
This is surprising behavior and at first glance violates the tracer's intent. As I understand it, `torch.fx.symbolic_trace.Tracer.trace` intends to patch `torch.nn.Module.__call__` with a `module_call_wrapper()` that records a `call_module` node if the module is a leaf, else executes `torch.fx._symbbolic_trace._orig_module_call = torch.nn.Module.__call__`, which is set a module loading time.
**Every submodule should be a leaf, but no `call_module` nodes are created when that submodule is a `GraphModule`. Why?**
Upon further inspection, I found:
- The constructor for GraphModule includes a path to `GraphModule.recompile()` via the setter for a `fx.Graph`:
```
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
File "/torch/fx/graph_module.py", line 252, in __init__
self.graph = graph
File "/torch/nn/modules/module.py", line 1183, in __setattr__
object.__setattr__(self, name, value)
File "/torch/fx/graph_module.py", line 277, in graph
self.recompile()
```
- `recompile()` wraps the `__call__` method by holding a reference to the `__call__` method at the time of recompilation:
```
cls = type(self)
cls_call = cls.__call__
...
def wrapped_call(self, *args, **kwargs):
try:
return cls_call(self, *args, **kwargs)
except Exception as e:
...
cls.__call__ = wrapped_call
```
- Recompilation of the inner GraphModule happens on initialization, before creation or tracing of the outer module. Adding some old-fashioned print debug statements gives:
```
Inner Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
recompile: cls.__call__ now wraps _orig_module_call, <function Module._call_impl at 0x7faaebfee8b0>
Outer Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
tracing: patching method <class 'torch.nn.modules.module.Module'>.__call__ <function Module._call_impl at 0x7faaebfee8b0> with <function Module._call_impl at 0x7fa9d42bce50>
outer module MRO before tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
outer module MRO during tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
inner module MRO before tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
inner module MRO during tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
```
- The outer module is patched correctly, but the inner module's first element in its MRO is the `wrapped_call` from `recompile` that still invokes `<function Module._call_impl at 0x7faaebfee8b0>` directly. Therefore, no call_module nodes are created.
## In Practice
In practice, this behavior affects the ability of `torch.package` to package `GraphModules` whose submodules are `GraphModules`. In our case, the `GraphModule` submodules are not passed through a constructor, but created separately and installed on the root `GraphModule` via `setattr`. This means that prior to packaging, there appear to be no issues with the module, since the root's graph was created before any call_module targets were replaced with `GraphModules`.
When unpackaging such a model with `torch.package`, `torch.fx.graph_module._deserialize_graph_module` uses an inline `KeepModules` tracer that sets all submodules to leaves; the unpackaged module is implicitly and surprisingly inlined in the process.
## Potential Solution
This behavior was previously not understood by us, and so the current workaround is a gnarly process of wrapping all submodules with a `nn.Module` with a manually installed forward method.
Changing `wrapped_call` to return `return super(type(self), self).__call__(*args, **kwargs)` whenever `__call__` is inherited at least appears to solve the issue. Does this seem like an acceptable approach?
## Other Thoughts
- Repeated calls to `recompile` create nested `wrapped_calls`, all for the purpose of error handling. This seems probably unnecessary ¯\\_(ツ)\_/¯
- If a root module with a overriden `__call__` method is symbolically traced, it is ignored
Test Plan:
```
buck test:
✓ ListingSuccess: caffe2/test:fx - main (12.570)
✓ Pass: caffe2/test:fx - test_tracing_graphmodules_as_leaf_submodules (test_fx.TestFX) (11.982)
```
Reviewed By: ansley
Differential Revision: D29997935
fbshipit-source-id: 1988fbb025b14188da26a3e73e94fb789c3c1f74