pytorch
fdb9441e - Stop recursion on trivial replacement (#97903)

Commit
1 year ago
Stop recursion on trivial replacement (#97903) Pattern replacement behaves incorrectly when the replacement pattern maps inputs to outputs (such a pattern can be used to replace redundant code). However, current code in `torch.fx.subgraph_rewriter._replace_pattern` causes the list of replacement nodes to include the entire graph before that node, resulting in an exponential slowdown due to recursive calls traversing the entire graph multiple times. The proposed fix is to add a check in `_replace_pattern` to prevent the call to `get_replacement_nodes`: ```python for ret_node in copied_returning_nodes: if ret_node in match.placeholder_nodes: replacement_nodes.append(ret_node) else: get_replacement_nodes(ret_node) ``` Fixes #97817 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97903 Approved by: https://github.com/angelayi
Author
Committer
Parents
Loading