Pretty print stack trace with gm.print_readable() (#83706)
Precondition: https://github.com/pytorch/torchdynamo/pull/899
Given the following function:
```
import torch

def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
    return s
```
Here are the possible results with various tracing frontends: dynamo, symbolic_trace, and make_fx. Each listing below is the output of calling `gm.print_readable()` on the resulting graph module.
- joint graph with torchdynamo.optimize("aot_nop")
Notice that it has a special stack trace for the gradient addition node (emitted when a tensor has multiple uses) in the backward pass.
Notice that "No stacktrace found for following nodes" is shown for nodes without a stack trace.
```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]); relu_default = sin_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default); stack_default = None
    # No stacktrace found for following nodes
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]); tangents_1 = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default); expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1]; unbind_int = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar); pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default); getitem_1 = cos_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default); detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0); getitem = detach_default_1 = None
    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default); mul_tensor = threshold_backward_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0); add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0); pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar); add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]); sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```
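The joint graph above comes out of the AOT Autograd pipeline that the `aot_nop` backend drives (it compiles both forward and backward with no-op compilers). As a rough sketch of one way to reproduce a similar printout, not necessarily how it was generated here, functorch's `aot_function` can be given a partition function that calls `print_readable()` on the joint graph before partitioning; the input shapes `(10, 10)` and `(10,)` are assumptions inferred from the shapes in the trace:
```
import torch
from functorch.compile import aot_function, default_partition, nop

def show_joint(joint_gm, joint_inputs, **kwargs):
    # Print the combined forward/backward graph before it is partitioned.
    joint_gm.print_readable()
    return default_partition(joint_gm, joint_inputs, **kwargs)

# Shapes are assumptions inferred from the traces above.
a = torch.randn(10, 10, requires_grad=True)
b = torch.randn(10, requires_grad=True)

compiled = aot_function(func, fw_compiler=nop, bw_compiler=nop, partition_fn=show_joint)
compiled(a, b).backward()
```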
- default symbolic_trace
Notice that nodes without a stack trace are folded under the same region.
```
def forward(self, a, b):
    # No stacktrace found for following nodes
    add = a + b; a = b = None
    square = torch.square(add); add = None
    relu = square.relu()
    sin = square.sin(); square = None
    stack = torch.stack([relu, sin]); relu = sin = None
    sum_1 = stack.sum(); stack = None
    return sum_1
```
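For reference, a minimal sketch of how the default-tracer printout above can be produced, assuming `func` from the snippet at the top:
```
import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(func)
gm.print_readable()  # without stack traces, all nodes fall under one "No stacktrace found" banner
```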
- symbolic_trace with record_stack_traces=True
```
def forward(self, a, b):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b; a = b = None
    square = torch.square(add); add = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin(); square = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]); relu = sin = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum(); stack = None
    return sum_1
```
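`record_stack_traces` is a flag on the FX tracer rather than an argument to `symbolic_trace`, so one way to get the annotated version above is to override it on a `Tracer` subclass (a sketch; the subclass name is made up here):
```
import torch
from torch.fx import GraphModule, Tracer

class StackTraceTracer(Tracer):
    # Ask the tracer to attach the Python stack trace to every node it creates.
    record_stack_traces = True

tracer = StackTraceTracer()
graph = tracer.trace(func)
gm = GraphModule(tracer.root, graph)
gm.print_readable()
```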
- make_fx without decomposition
```
def forward(self, a_1, b_1):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1); a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2); add_tensor = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar); pow_tensor_scalar = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]); relu_default = sin_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default); stack_default = None
    return sum_default
```
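A minimal sketch of producing the `make_fx` graph above; the input shapes are assumptions inferred from the traces, and `make_fx` also accepts a `decomposition_table` argument for decomposed variants like the one shown next:
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Shapes are assumptions inferred from the traces above.
a = torch.randn(10, 10)
b = torch.randn(10)

gm = make_fx(func)(a, b)
gm.print_readable()
```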
- make_fx with decomposition to prims
```
def forward(self, a_1, b_1):
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]); b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default); a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default); add_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default); le_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default); mul_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0); where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2); cat_default = None
    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32); split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]); convert_element_type_default = None
    return sum_default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83706
Approved by: https://github.com/Chillee, https://github.com/ezyang