pytorch
13941307 - [FX] make ASTReriter patch wrapped functions properly (#62987)

Commit
3 years ago
[FX] make ASTReriter patch wrapped functions properly (#62987) Summary: reference the same global namespace (instead of copying it) in ASTRewriter to patch wrapped functions properly Fixes #{62071} Pull Request resolved: https://github.com/pytorch/pytorch/pull/62987 Test Plan: To test it you may write this snippet and ensure the results are as shown in the comments: ``` import torch import torch.fx torch.fx.wrap def to_be_wrapped(x): return torch.relu(x) class Foo(torch.nn.Module): def forward(self, x): return to_be_wrapped(x) traced = torch.fx.symbolic_trace(Foo()) print(traced.graph) """ graph(): %x : [#users=1] = placeholder[target=x] %to_be_wrapped : [#users=1] = call_function[target=__main__.to_be_wrapped](args = (%x,), kwargs = {}) return to_be_wrapped """ from torch.fx.experimental.rewriter import RewritingTracer rt = RewritingTracer() graph = rt.trace(Foo()) print(graph) """ ### AFTER FIX (CORRECT): graph(): %x : [#users=1] = placeholder[target=x] %to_be_wrapped : [#users=1] = call_function[target=__main__.to_be_wrapped](args = (%x,), kwargs = {}) return to_be_wrapped ### BEFORE FIX (WRONG): graph(): %x : [#users=1] = placeholder[target=x] %relu : [#users=1] = call_function[target=torch.relu](args = (%x,), kwargs = {}) return relu """ ``` Reviewed By: ansley Differential Revision: D30396176 Pulled By: mostafaelhoushi fbshipit-source-id: f61eddf32e9ef42b5f5c3ce21d559945214ee833
Parents
Loading