[torch.fx] Fix replace pattern mechanism (#66442)
Summary:
Fixes #{issue number}
The following code would not return the pattern correctly:
```python
def f(x):
x = torch.sigmoid(x)
x = torch.sigmoid(x)
return torch.sigmoid(x)
def pattern(x):
return torch.sigmoid(x)
def replacement(x):
return torch.exp(x)
def comparison(x):
x = torch.exp(x)
x = torch.exp(x)
return torch.exp(x)
traced = symbolic_trace(f)
comparison_fn = symbolic_trace(comparison)
subgraph_rewriter.replace_pattern(traced, pattern, replacement) # Only one sigmoid gets converted.
```
This PR fixes this by adding a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66442
Reviewed By: ZolotukhinM
Differential Revision: D32238424
Pulled By: ansley
fbshipit-source-id: 386e777174c639baafc166d5ffbc0658a96b1ee9