pytorch
3c5ec6af - Partition modules (#98628)

Commit
1 year ago
Partition modules (#98628) Added helper functions to match nodes in the graph that are decomposed from their source (leaf modules, or functional ops), as a result of dynamo tracing. `get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]` Args: * graph: The graph we want to partition * wanted_sources: List of sources of nodes that were decomposed from this source. This can be a function (ex. torch.nn.functional.linear) or a leaf module type (ex. torch.nn.Linear) Returns: * Dictionary mapping sources (ex. torch.nn.modules.linear.Linear) to a list of SourcePartitions that correspond to the list of nodes that were flattened from a module of that type. ``` @dataclass class SourcePartition(): # Nodes in a particular partition nodes: List[Node] # Module type module_type: Type # Nodes in the graph that are needed as inputs to the partition input_nodes: List[Node] = field(default_factory=list) # Nodes in the partition that are being used by nodes outside of the partition output_nodes: List[Node] = field(default_factory=list) # Parameters that are being used params: List[str] = field(default_factory=list) ``` Example: Original: ``` x -> linear -> linear -> relu -> linear ``` Traced graph: ``` .graph(): %arg0 : [#users=1] = placeholder[target=arg0] %_param_constant0 : [#users=1] = get_attr[target=_param_constant0] %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {}) %_param_constant1 : [#users=1] = get_attr[target=_param_constant1] %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {}) %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0] %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {}) %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1] %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {}) %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {}) %_param_constant2 : [#users=1] = get_attr[target=_param_constant2] %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {}) %_param_constant3 : [#users=1] = get_attr[target=_param_constant3] %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {}) return [addmm_default_2] ``` Result of `get_module_partitions`: ``` {<class 'torch.nn.modules.linear.Linear'>: [ ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]), ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]), ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])], <class 'torch.nn.modules.activation.ReLU'>: [ ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]} ``` Also added helper function to check if two module partitions are connected: `check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool` Pull Request resolved: https://github.com/pytorch/pytorch/pull/98628 Approved by: https://github.com/cccclai
Author
Committer
Parents
Loading