[quant][fx] Add root_node_getter in backend_config_dict (#73345)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73345
For complex patterns we need to identify which node is the root, so that we can eliminate all other nodes and only preserve the root,
e.g. (torch.add, MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)), we can preserve the torch.nn.Conv2d as root node, and remove other nodes.
Prevoiusly we assumed the root_node of a pattern is the "last node" of the pattern, computed by:
```
def default_root_node_getter(node_pattern):
while not isinstance(node_pattern[-1], Node):
node_pattern = node_pattern[-1]
return node_pattern[-1]
```
This PR enables user configuration to define their own root_node_getter, that means we can define root_node for patterns like:
(torch.add, (torch.nn.ReLU, torch.nn.Conv2d), MatchAllNode)
Test Plan:
python test/test_quantize_fx.py TestFuseFx.test_root_node_getter
Imported from OSS
Reviewed By: VitalyFedyunin
Differential Revision: D34442193
fbshipit-source-id: 2f6da69a5b6527b49710ae32820e8e2915d9af37
(cherry picked from commit 8b49bf0d7d53cdcf2c9f40f8e25bc843e8814026)