Include fused nodes' debug_str in FusedSchedulerNode::debug_str_extra (#106356)
Currently, there is no way to print the debug information of fused scheduler nodes. I'm adding this to inspect the individual nodes' IR type (e.g. `ComputedBuffer`), though it may be useful for other use cases as well.
FusedSchedulerNode::debug_str_extra currently prints only the names of its fused nodes:
```
# calling .debug_str() on a FusedSchedulerNode
buf0_buf1: FusedSchedulerNode(NoneType)
buf0_buf1.writes = [MemoryDep('buf0', c0, {c0: 10}), MemoryDep('buf1', c0, {c0: 10})]
buf0_buf1.unmet_dependencies = []
buf0_buf1.met_dependencies = [MemoryDep('arg0_1', c0, {c0: 100}), MemoryDep('arg1_1', c0, {c0: 10})]
buf0_buf1.users = None
buf0_buf1.snodes = ['buf0', 'buf1']
```
This PR adds support for printing the fused nodes' full `debug_str`:
```
buf0_buf1: FusedSchedulerNode(NoneType)
buf0_buf1.writes = [MemoryDep('buf0', c0, {c0: 10}), MemoryDep('buf1', c0, {c0: 10})]
buf0_buf1.unmet_dependencies = []
buf0_buf1.met_dependencies = [MemoryDep('arg0_1', c0, {c0: 100}), MemoryDep('arg1_1', c0, {c0: 10})]
buf0_buf1.users = None
buf0_buf1.snodes[0] =
buf0: SchedulerNode(ComputedBuffer)
buf0.writes = [MemoryDep('buf0', c0, {c0: 10})]
buf0.unmet_dependencies = []
buf0.met_dependencies = [MemoryDep('arg0_1', c0, {c0: 100})]
buf0.users = [NodeUser(node=SchedulerNode(name='buf1'), can_inplace=True)]
buf0.group.device = cuda:0
buf0.group.iteration = (10, 10)
buf0.sizes = ([10], [10])
class buf0_loop_body:
var_ranges = {z0: 10, z1: 10}
index0 = 10*z0 + z1
index1 = z0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('arg0_1', get_index)
reduction = ops.reduction(torch.float32, torch.float32, 'sum', load)
get_index_1 = self.get_index('index1')
store_reduction = ops.store_reduction('buf0', get_index_1, reduction)
return store_reduction
buf0_buf1.snodes[1] =
buf1: SchedulerNode(ComputedBuffer)
buf1.writes = [MemoryDep('buf1', c0, {c0: 10})]
buf1.unmet_dependencies = [MemoryDep('buf0', c0, {c0: 10})]
buf1.met_dependencies = [MemoryDep('arg1_1', c0, {c0: 10})]
buf1.users = [NodeUser(node=OUTPUT, can_inplace=False)]
buf1.group.device = cuda:0
buf1.group.iteration = (10, 1)
buf1.sizes = ([10], [])
class buf1_loop_body:
var_ranges = {z0: 10}
index0 = z0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('arg1_1', get_index)
cos = ops.cos(load)
get_index_1 = self.get_index('index0')
load_1 = ops.load('buf0', get_index_1)
add = ops.add(cos, load_1)
get_index_2 = self.get_index('index0')
store = ops.store('buf1', get_index_2, add, None)
return store
```
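The change can be sketched roughly as follows. This is a simplified, self-contained illustration of the pattern (the class and method names mirror the real inductor scheduler, but the bodies here are stand-ins, not the actual PyTorch implementation): the fused node's `debug_str_extra` now enumerates each child node and embeds that child's indented `debug_str` instead of just listing names.

```python
import textwrap

# Hypothetical stand-in for inductor's SchedulerNode, for illustration only.
class SchedulerNode:
    def __init__(self, name: str, ir_type: str):
        self.name = name
        self.ir_type = ir_type

    def debug_str(self) -> str:
        # The real node prints writes, dependencies, users, sizes, etc.
        return f"{self.name}: SchedulerNode({self.ir_type})"

# Hypothetical stand-in for inductor's FusedSchedulerNode.
class FusedSchedulerNode:
    def __init__(self, snodes: list[SchedulerNode]):
        self.snodes = snodes
        self.name = "_".join(n.name for n in snodes)

    def debug_str_extra(self) -> str:
        # Before: only the child names, e.g. "snodes = ['buf0', 'buf1']".
        # After: each child's full debug_str, indexed and indented.
        lines = []
        for i, node in enumerate(self.snodes):
            lines.append(f"{self.name}.snodes[{i}] =")
            lines.append(textwrap.indent(node.debug_str(), "    "))
        return "\n".join(lines)

fused = FusedSchedulerNode(
    [SchedulerNode("buf0", "ComputedBuffer"), SchedulerNode("buf1", "ComputedBuffer")]
)
out = fused.debug_str_extra()
print(out)
```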
I'm assuming that a FusedSchedulerNode cannot itself be fused, i.e. FusedSchedulerNode::snodes never contains another FusedSchedulerNode.
# Tests
Changes were tested ad hoc by printing `debug_str` in GraphLowering::count_bytes and running `python3 test/inductor/test_perf.py -k test_fusion_choice3`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106356
Approved by: https://github.com/peterbell10