pytorch
2c905f21 - Extend Pattern Matcher to allow handling split-cat style patterns (#97726)

Commit
1 year ago
Extend Pattern Matcher to allow handling split-cat style patterns (#97726) Summary: This diff extends pattern matcher, by adding a few features which allows it to handle split-getitem-cat style patterns. 3 problems I encountered were: 1. In the handler, I only need one Arg() (the one which is the first input to split). None of the other args are relevant to replacement graph. So, we add a new Ignored() pattern to have ignored args 2. The pattern matching was visiting the split node again and again during the DFS. By propogating the patterns with _users>1 or Any into the child MatchContext, we avoid this problem. 3. To avoid the unbundling issue, I switched to using KeywordArg() instead of Arg() - as for this pattern, we need a flat list of Arg() in the end Example pattern: https://www.internalfb.com/intern/anp/view/?id=3325856 ``` pass_patterns.append(defaultdict(list)) register_replacement_pattern( CallFunction( aten.cat, ListOf( CallFunction(operator.getitem, CallFunction(aten.split_with_sizes, KeywordArg("input_"), Ignored(), Ignored(), _users=Any), Ignored() ),), Ignored() ), pass_number=3 ) def split_cat_replace(input_): return input_ ``` Test Plan: https://www.internalfb.com/intern/anp/view/?kernel=default&id=3317105 Reviewed By: jansel Differential Revision: D44282499 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97726 Approved by: https://github.com/jansel
Committer
Parents
Loading