pytorch
72f446b9 - Remove getitem special handling in the partitioner (#87073)

Commit
2 years ago
Remove getitem special handling in the partitioner (#87073) This special handling of getitem unnecessary splits fusions at functions with tuple outputs. Example script: ```py import torch from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport from torch.fx.experimental.proxy_tensor import make_fx def func(x): xx = torch.ops.nvprims.add(x, 1) var, mean = torch.ops.nvprims.var_mean(x, correction=0) var_cos = torch.ops.nvprims.cos(var) mean_sin = torch.ops.nvprims.sin(mean) return torch.ops.nvprims.add(var_cos, mean_sin) a = torch.randn(5, 3, 3, device="cuda") gm = make_fx(func)(a) gm.graph.print_tabular() supported_ops = NvfuserPrimOperatorSupport() partitioner = CapabilityBasedPartitioner( gm, supported_ops, allows_single_node_partition=False ) partitions = partitioner.propose_partitions() print(partitions) partitioned_graph = partitioner.fuse_partitions(partitions) partitioned_graph.graph.print_tabular() ``` Output on master: ```py opcode name target args kwargs ------------- --------- --------------------------- ---------------- ----------------- placeholder x_1 x_1 () {} call_function add nvprims.add.default (x_1, 1) {} call_function var_mean nvprims.var_mean.main (x_1, [0, 1, 2]) {'correction': 0} call_function getitem <built-in function getitem> (var_mean, 0) {} call_function getitem_1 <built-in function getitem> (var_mean, 1) {} call_function cos nvprims.cos.default (getitem,) {} call_function sin nvprims.sin.default (getitem_1,) {} call_function add_1 nvprims.add.default (cos, sin) {} output output output (add_1,) {} [{cos, sin, add_1}, {var_mean, add, getitem, getitem_1}] opcode name target args kwargs ------------- --------- --------------------------- ---------------------- -------- placeholder x_1 x_1 () {} call_module fused_1 fused_1 (x_1,) {} call_function getitem_2 <built-in function getitem> (fused_1, 0) {} call_function getitem_3 <built-in function getitem> (fused_1, 1) {} call_module fused_0 fused_0 (getitem_2, getitem_3) {} output output output (fused_0,) {} ``` Output with this PR: ``` [{var_mean, add_1, cos, sin, add, getitem_1, getitem}] opcode name target args kwargs ----------- ------- -------- ---------- -------- placeholder x_1 x_1 () {} call_module fused_0 fused_0 (x_1,) {} output output output (fused_0,) {} ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/87073 Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
Author
Committer
Parents
Loading