SubgraphMatcher: matching modules support. (#25075)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25075
This change adds a special behavior to subgraph matcher to allow it to
match calls to modules. Namely, when a node in the pattern graph has a
'match::module' type, it is considered 'match' only when the
corresponding node in the target graph is a 'prim::GetAttr' obtaining a
submodule which type matches the type specified in 'name' attribute of
the 'match::module' node.
Currently when comparing the expected module type we check if the string
specified in 'name' prefixes qualified name of the module GetAttr
returns. In future when qualified name format is better defined we will
probably change it for the exact comparison.
Why do we want this? In some cases we would like to perform fusion on a
module level rather than on a graph-level. A popular example of such
fusion would be Conv-BN. It is inpractical to match batchnorm on
graph-evel because that would mean we woudl need to specify its full
and exact implementation in the pattern graph. If we match on the
CallMethod level, however, the problem becomes trivial.
The feature added in this PR allows to detect patterns with 'CallMethod'
nodes, which in its turn allows us to use subgraph rewriter to replace
such patterns with some node (or nodes). I expect that a usual approach
would be to use subgraph-rewriter to replace all matches with some
artificial node and then in additional pass replace such nodes with
calls to another module or something else. It is not possible at the
moment to use subgraph-rewriter to add a call to a method of a new
module, because it can not add a new submodule, but we probably would
add a higher level API to do that.
Test Plan: Imported from OSS
Differential Revision: D16978652
Pulled By: ZolotukhinM
fbshipit-source-id: 37307a5ec65cf4618ad8eb595ef5f8ae656e2713
Author
Mikhail Zolotukhin