8d363d37 - [FX] Adds PyTree support to FX through `concrete_args` (#55888)

Summary:

```python
class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x

example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})
fx.symbolic_trace(new_f, concrete_args=example_dict)
```

prints out

```python
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y;  y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```

Currently, I store `in_spec` as an extra attribute on `fx.Graph` and include it when we do the codegen. I'm not sure this is the right approach: it introduces a divergence between what's in `fx.Graph` and what's in the Python code. Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit more verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888

Reviewed By: jamesr66a

Differential Revision: D27884694

Pulled By: Chillee

fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
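For readers unfamiliar with the pytree machinery the generated `forward` relies on: a pytree flatten turns a nested container into a flat list of leaves plus a "spec" describing the structure, and unflatten inverts it. Below is a minimal, self-contained sketch of that idea for dicts and lists only; it is an illustration written for this note, not the actual `torch.utils._pytree` implementation (which also handles tuples, namedtuples, and user-registered types).

```python
# Simplified sketch of pytree flatten/unflatten for dicts and lists.
# Illustrative only -- not the real torch.utils._pytree code.

def tree_flatten(tree):
    """Flatten a nested dict/list into (leaves, spec)."""
    if isinstance(tree, dict):
        specs, leaves = {}, []
        for k in sorted(tree):  # fixed key order keeps flattening deterministic
            sub_leaves, sub_spec = tree_flatten(tree[k])
            leaves.extend(sub_leaves)
            specs[k] = sub_spec
        return leaves, ('dict', specs)
    if isinstance(tree, list):
        specs, leaves = [], []
        for item in tree:
            sub_leaves, sub_spec = tree_flatten(item)
            leaves.extend(sub_leaves)
            specs.append(sub_spec)
        return leaves, ('list', specs)
    # Anything else is a leaf (a tensor, in the FX use case).
    return [tree], ('leaf', None)

def tree_unflatten(leaves, spec):
    """Rebuild the nested structure by consuming leaves in order."""
    kind, meta = spec
    if kind == 'leaf':
        return leaves.pop(0)
    if kind == 'dict':
        return {k: tree_unflatten(leaves, meta[k]) for k in sorted(meta)}
    return [tree_unflatten(leaves, s) for s in meta]

# Mirrors the shape of example_dict above, with ints standing in for tensors:
tree = {'x': {'a': [1], 'z': [2, 3]}}
leaves, spec = tree_flatten(tree)
print(leaves)                              # [1, 2, 3]
print(tree_unflatten(list(leaves), spec))  # {'x': {'a': [1], 'z': [2, 3]}}
```

The spec here plays the role of the `in_spec` attribute discussed below: it is exactly the structural information the traced graph must remember in order to map a flat list of placeholder leaves back onto the caller's nested arguments.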