Remove getitem special handling in the partitioner (#87073)
This special handling of getitem unnecessary splits fusions at functions with tuple outputs.
Example script:
```py
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch.fx.experimental.proxy_tensor import make_fx
def func(x):
xx = torch.ops.nvprims.add(x, 1)
var, mean = torch.ops.nvprims.var_mean(x, correction=0)
var_cos = torch.ops.nvprims.cos(var)
mean_sin = torch.ops.nvprims.sin(mean)
return torch.ops.nvprims.add(var_cos, mean_sin)
a = torch.randn(5, 3, 3, device="cuda")
gm = make_fx(func)(a)
gm.graph.print_tabular()
supported_ops = NvfuserPrimOperatorSupport()
partitioner = CapabilityBasedPartitioner(
gm, supported_ops, allows_single_node_partition=False
)
partitions = partitioner.propose_partitions()
print(partitions)
partitioned_graph = partitioner.fuse_partitions(partitions)
partitioned_graph.graph.print_tabular()
```
Output on master:
```py
opcode name target args kwargs
------------- --------- --------------------------- ---------------- -----------------
placeholder x_1 x_1 () {}
call_function add nvprims.add.default (x_1, 1) {}
call_function var_mean nvprims.var_mean.main (x_1, [0, 1, 2]) {'correction': 0}
call_function getitem <built-in function getitem> (var_mean, 0) {}
call_function getitem_1 <built-in function getitem> (var_mean, 1) {}
call_function cos nvprims.cos.default (getitem,) {}
call_function sin nvprims.sin.default (getitem_1,) {}
call_function add_1 nvprims.add.default (cos, sin) {}
output output output (add_1,) {}
[{cos, sin, add_1}, {var_mean, add, getitem, getitem_1}]
opcode name target args kwargs
------------- --------- --------------------------- ---------------------- --------
placeholder x_1 x_1 () {}
call_module fused_1 fused_1 (x_1,) {}
call_function getitem_2 <built-in function getitem> (fused_1, 0) {}
call_function getitem_3 <built-in function getitem> (fused_1, 1) {}
call_module fused_0 fused_0 (getitem_2, getitem_3) {}
output output output (fused_0,) {}
```
Output with this PR:
```
[{var_mean, add_1, cos, sin, add, getitem_1, getitem}]
opcode name target args kwargs
----------- ------- -------- ---------- --------
placeholder x_1 x_1 () {}
call_module fused_0 fused_0 (x_1,) {}
output output output (fused_0,) {}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87073
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad