Make make_fx cond preserve node meta (#108356)
**Motivation:**
Currently, for the following code that exports the cond operator:
```python
import torch
from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])


from torch._export import capture_pre_autograd_graph

example_inputs = (torch.randn(1, 3, 3, 3),)
m = CondBranchClassMethod()
m.eval()
gm = capture_pre_autograd_graph(m, example_inputs)
print(gm)

# source_fn for the original cond op and the getattr submodule nodes is the cond op
for n in gm.graph.nodes:
    print("n:", n.format_node(), n.meta)

print("\n\n\n")

# source_fn for the submodule nodes is also the cond op
# Expected: ideally this should be the real ops, e.g. torch.sin, aten.cos, etc.
for n in gm.submodule_0.graph.nodes:
    print("n:", n.format_node(), n.meta)
```
The output looks like this:
```
GraphModule(
  (submodule_0): GraphModule()
  (submodule_1): GraphModule()
)

def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec([arg_0], self._in_spec)
    submodule_0 = self.submodule_0
    submodule_1 = self.submodule_1
    cond = torch.ops.higher_order.cond(True, submodule_0, submodule_1, [arg0_1]); submodule_0 = submodule_1 = arg0_1 = None
    return pytree.tree_unflatten((cond,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
n: %arg0_1 : [num_users=1] = placeholder[target=arg0_1] {'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None, 'is_torch_exported': True, 'stack_trace': 'NoneType: None\n'}
n: %submodule_0 : [num_users=1] = get_attr[target=submodule_0] {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1}
n: %submodule_1 : [num_users=1] = get_attr[target=submodule_1] {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1}
n: %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (True, %submodule_0, %submodule_1, [%arg0_1]), kwargs = {}) {'stack_trace': 'NoneType: None\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('conditional', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>)], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None, 'is_torch_exported': True}
n: return (cond,) {'stack_trace': 'NoneType: None\n', 'from_node': [('output', 'output')], 'seq_nr': -1, 'is_torch_exported': True, 'val': (FakeTensor(..., size=(1, 3, 3, 3)),), 'tensor_meta': (None,)}
n: %arg0_1 : [num_users=1] = placeholder[target=arg0_1] {'stack_trace': ' File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('arg0_1', 'arg0_1')], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
n: %cos_default : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%arg0_1,), kwargs = {}) {'stack_trace': ' File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': <OpOverload(op='aten.cos', overload='default')>, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('cos', <OpOverload(op='aten.cos', overload='default')>), ('cos_default', <OpOverload(op='aten.cos', overload='default')>)], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
n: return cos_default {'stack_trace': ' File "<ipython-input-9-2a8c7c0498ed>", line 36, in forward\n return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])\n', 'source_fn': ('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), 'original_aten': None, 'from_node': [('cond', <torch._ops.HigherOrderOperator object at 0x7f68ae93efd0>), ('output', 'output')], 'seq_nr': -1, 'val': FakeTensor(..., size=(1, 3, 3, 3)), 'tensor_meta': None}
```
As we can see, the meta of the nodes in the subgraphs is overridden with the cond node's metadata. This happens because `_set_current_meta` is only invoked for the top-level graph module in the interpreter: when we call into cond and execute the submodules, current_meta is never set to the meta of the subgraph's own nodes.
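To make the mechanism concrete, here is a small self-contained illustration of the per-node hook where an interpreter can publish a node's meta (`MetaPrinter` is a hypothetical helper for this example, not part of the PR):
```python
import torch
import torch.fx as fx

def f(x):
    return x.cos() + x.sin()

gm = fx.symbolic_trace(f)

class MetaPrinter(fx.Interpreter):
    def run_node(self, n):
        # run_node is the per-node hook: this is the point where an
        # interpreter can expose n.meta as the "current" meta, so that any
        # ops traced while executing n inherit it. Here we just print it.
        print(n.name, n.meta)
        return super().run_node(n)

MetaPrinter(gm).run(torch.randn(3))
```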
**Implementation:**
This PR fixes it as follows: in `trace_cond`, we optionally run the subgraphs through an `fx.Interpreter` so that node meta is preserved, but only when both of the following conditions hold (see the sketch after this list):
- The subgraphs are `GraphModule`s: this is necessary for using `fx.Interpreter`.
- The current `make_fx` has turned `preserve_node_meta` on (as is the case for `capture_pre_autograd_graph`).
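A minimal sketch of the wrapping logic, not the verbatim diff; the helper name `_maybe_wrap_with_interpreter` is hypothetical, and it assumes the `has_preserved_node_meta`/`preserve_node_meta` helpers in `torch.fx.traceback`:
```python
import torch.fx as fx
import torch.fx.traceback as fx_traceback

def _maybe_wrap_with_interpreter(fn):
    # Hypothetical helper. Only wrap when (1) the branch is a GraphModule
    # and (2) make_fx is currently preserving node meta.
    if isinstance(fn, fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        def run_with_interpreter(*args):
            # fx.Interpreter advances the current meta node by node, so each
            # subgraph node's own meta (source_fn, stack_trace, ...) gets
            # recorded instead of being overridden by the enclosing cond's.
            with fx_traceback.preserve_node_meta():
                return fx.Interpreter(fn).run(*args)
        return run_with_interpreter
    return fn
```
In `trace_cond`, the true and false branches would then be traced via `_maybe_wrap_with_interpreter(true_fn)` / `_maybe_wrap_with_interpreter(false_fn)` rather than the raw callables.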
**Test Plan:**
See added tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108356
Approved by: https://github.com/SherlockNoMad