Fuse nodes with sizes (s0*s1*...,) and (s0, s1, s2, ...) (#120077)
Description:
- This PR fuses nodes with compatible sizes, for example `node1: (s0, s1, s2)` and `node2: (s0 * s1 * s2,)`. On `main` these two nodes cannot be fused because their sizes differ. With this PR we recompute node2's size, body, etc. using node1's indexing constraints and can therefore fuse the two nodes (see the sketch below).
- This should affect only the CPU device.
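The core idea, as a minimal sketch (illustrative names, not the actual inductor implementation): a body written over the flat range `(s0 * s1 * s2,)` can be re-expressed over the multi-dimensional ranges `(s0, s1, s2)` of its fusion partner, because the flat index is recoverable from the multi-dimensional indices.
```python
def flatten_index(i0: int, i1: int, i2: int, s1: int, s2: int) -> int:
    # Row-major (contiguous) linearization of a (s0, s1, s2) index,
    # analogous to the offsets computed by inner_fn in the logs below.
    return (i0 * s1 + i1) * s2 + i2

def node2_body_flat(flat_index: int) -> float:
    # Stand-in for node2's original pointwise body over range(s0*s1*s2).
    return float(flat_index)

def node2_body_reindexed(i0: int, i1: int, i2: int, s1: int, s2: int) -> float:
    # node2's body recomputed under node1's (s0, s1, s2) indexing, after
    # which the two loop nests match and the nodes can fuse.
    return node2_body_flat(flatten_index(i0, i1, i2, s1, s2))

# Sanity check: iterating the multi-dimensional ranges visits exactly the
# same elements as iterating the flat range.
s0, s1, s2 = 2, 3, 4
flat = [node2_body_flat(i) for i in range(s0 * s1 * s2)]
reindexed = [
    node2_body_reindexed(i0, i1, i2, s1, s2)
    for i0 in range(s0)
    for i1 in range(s1)
    for i2 in range(s2)
]
assert flat == reindexed
```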
Example:
```python
from unittest.mock import patch
import torch
from torch._inductor.graph import GraphLowering
from torch._inductor import config
# Force creation of multiple scheduler nodes so there is something to fuse
config.realize_opcount_threshold = 1
@torch.compile(fullgraph=True, dynamic=True)
def fn(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
o1 = x * w1.view(1, 1, 1, -1)
o2 = x * w2.view(1, 1, 1, -1)
output = o1 + o2
return output
in_nodes = []
outputs = []
run_node = GraphLowering.run_node
graph_lowering_obj = None
# Wrap GraphLowering.run_node to capture the GraphLowering instance,
# so its buffers and scheduler nodes can be inspected after compilation.
def run_node_alt(self, n):
global graph_lowering_obj
graph_lowering_obj = self
in_nodes.append(n)
output = run_node(self, n)
outputs.append(output)
return output
x = torch.rand(1, 3, 32, 32)
w1 = torch.randn(32)
w2 = torch.randn(32)
with patch.object(GraphLowering, "run_node", run_node_alt):
fn(x, w1, w2)
print("graph_lowering_obj.buffers:", graph_lowering_obj.buffers)
print("graph_lowering_obj.scheduler:", graph_lowering_obj.scheduler.nodes)
```
Output on `main`:
```
graph_lowering_obj.buffers: [ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
_, i1, i2, i3 = index
tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
tmp1 = ops.load(arg1_1, i3)
tmp2 = tmp0 * tmp1
return tmp2
,
ranges=[1, s1, s0, s0],
origin_node=mul,
origins={mul}
)), ComputedBuffer(name='buf1', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
_, i1, i2, i3 = index
tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
tmp1 = ops.load(arg4_1, i3)
tmp2 = tmp0 * tmp1
return tmp2
,
ranges=[1, s1, s0, s0],
origin_node=mul_1,
origins={mul_1}
)), ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
_, i1, i2, i3 = index
tmp0 = ops.load(buf0, i3 + i1 * s0**2 + i2 * s0)
tmp1 = ops.load(buf1, i3 + i1 * s0**2 + i2 * s0)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[1, s1, s0, s0],
origin_node=add,
origins={add}
))]
graph_lowering_obj.scheduler: [FusedSchedulerNode(nodes=buf0_buf1), SchedulerNode(name='buf2')]
```
Output on this PR:
```
graph_lowering_obj.buffers: [ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
_, i1, i2, i3 = index
tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
tmp1 = ops.load(arg1_1, i3)
tmp2 = tmp0 * tmp1
return tmp2
,
ranges=[1, s1, s0, s0],
origin_node=mul,
origins={mul}
)), ComputedBuffer(name='buf1', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
_, i1, i2, i3 = index
tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
tmp1 = ops.load(arg4_1, i3)
tmp2 = tmp0 * tmp1
return tmp2
,
ranges=[1, s1, s0, s0],
origin_node=mul_1,
origins={mul_1}
)), ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
'cpu',
torch.float32,
def inner_fn(index):
_, i1, i2, i3 = index
tmp0 = ops.load(buf0, i3 + i1 * s0**2 + i2 * s0)
tmp1 = ops.load(buf1, i3 + i1 * s0**2 + i2 * s0)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[1, s1, s0, s0],
origin_node=add,
origins={add}
))]
graph_lowering_obj.scheduler: [FusedSchedulerNode(nodes=buf0_buf1_buf2)]
```
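As a quick programmatic check on the script above (a hypothetical assertion, relying on the `graph_lowering_obj` captured by the monkey-patch):
```python
# With this PR, all three buffers end up in a single fused scheduler node.
nodes = graph_lowering_obj.scheduler.nodes
assert len(nodes) == 1, f"expected one fused node, got {nodes}"
```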
Context:
While working on https://github.com/pytorch/pytorch/pull/120411 (upsampling bicubic decomposition), I saw an extra for-loop in the generated C++ code summing two buffers. Investigating the cause, it turned out that the buffer's op count went beyond `config.realize_opcount_threshold`, so the buffer was realized separately and the addition could not be fused.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120077
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/peterbell10