[FX] make ASTReriter patch wrapped functions properly (#62987)
Summary:
reference the same global namespace (instead of copying it) in ASTRewriter to patch wrapped functions properly
Fixes #{62071}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62987
Test Plan:
To test it you may write this snippet and ensure the results are as shown in the comments:
```
import torch
import torch.fx
torch.fx.wrap
def to_be_wrapped(x):
return torch.relu(x)
class Foo(torch.nn.Module):
def forward(self, x):
return to_be_wrapped(x)
traced = torch.fx.symbolic_trace(Foo())
print(traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%to_be_wrapped : [#users=1] = call_function[target=__main__.to_be_wrapped](args = (%x,), kwargs = {})
return to_be_wrapped
"""
from torch.fx.experimental.rewriter import RewritingTracer
rt = RewritingTracer()
graph = rt.trace(Foo())
print(graph)
"""
### AFTER FIX (CORRECT):
graph():
%x : [#users=1] = placeholder[target=x]
%to_be_wrapped : [#users=1] = call_function[target=__main__.to_be_wrapped](args = (%x,), kwargs = {})
return to_be_wrapped
### BEFORE FIX (WRONG):
graph():
%x : [#users=1] = placeholder[target=x]
%relu : [#users=1] = call_function[target=torch.relu](args = (%x,), kwargs = {})
return relu
"""
```
Reviewed By: ansley
Differential Revision: D30396176
Pulled By: mostafaelhoushi
fbshipit-source-id: f61eddf32e9ef42b5f5c3ce21d559945214ee833