Print full exception info in Graph break log (#119292)
So, this is a little awkward, so I don't mind more thoughts on how best to do this.
Let's suppose that you have a graph break inside of an inlined function call. We are not actually going to print this graph break yet; instead, we are going to restart analysis so that we can run up until the inlined function call. When this happens, the only log message we ever get is the log to `graph_break` (seen here) reporting that a graph break has occurred.
In the current code, we don't print the fully formatted exception if you are only using `graph_breaks` logging. So the exception that induced the graph break has its traceback lost forever. For some classes of errors, esp., guard on data-dependent SymInt, this is quite bad.
With this change, we do print the traceback. On this sample program:
```
import torch
import torch._dynamo.config
torch._dynamo.config.capture_scalar_outputs = True
def g(x, y):
y = x.item()
if y < 3:
return x + 2
else:
return x + 3
@torch.compile()
def f(x, y):
y = y * y
return g(x, y)
f(torch.tensor(4), torch.randn(4))
```
It looks like this:
```
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Graph break: Traceback (most recent call last):
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/tensor.py", line 878, in evaluate_expr
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return guard_scalar(self.sym_num)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 414, in guard_scalar
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return guard_bool(a)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 663, in guard_bool
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return a.node.guard_bool("", 0) # NB: uses Python backtrace
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/sym_node.py", line 366, in guard_bool
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return fn(*args, **kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3670, in evaluate_expr
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] concrete_val = self.size_hint(orig_expr)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3403, in size_hint
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] raise self._make_data_dependent_error(result_expr, expr)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: It appears that you're trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is u0 < 3 (unhinted: u0 < 3). For more information, run with TORCH_LOGS="+dynamic".
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] During handling of the above exception, another exception occurred:
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Traceback (most recent call last):
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return inner_fn(self, inst)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] self.call_function(fn, args, {})
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] self.push(fn.call_function(self, args, kwargs))
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/functions.py", line 279, in call_function
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return super().call_function(tx, args, kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/functions.py", line 87, in call_function
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return tx.inline_user_function_return(
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in inline_user_function_return
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2262, in inline_call
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return cls.inline_call_(parent, func, args, kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2372, in inline_call_
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] tracer.run()
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] and self.step()
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] getattr(self, inst.opname)(inst)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 431, in inner
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] eval_result = value.evaluate_expr(self.output)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/tensor.py", line 880, in evaluate_expr
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] raise UserError( # noqa: TRY200
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] torch._dynamo.exc.UserError: Consider annotating your code using torch._constrain_as_*(). It appears that you're trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is u0 < 3 (unhinted: u0 < 3). For more information, run with TORCH_LOGS="+dynamic".
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] From user code at:
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/b.py", line 16, in f
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return g(x, y)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/b.py", line 8, in g
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] if y < 3:
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
```
The end of the log at restarted computation maybe can be improved too. Right now it looks like this:
```
[2024-02-06 10:32:24,338] [0/0_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 2 [UserFunctionVariable(), LazyVariableTracker(), TensorVariable()]
[2024-02-06 10:32:24,338] [0/0_1] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='Consider annotating your code using torch._constrain_as_*(). It appears that you\'re trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is u0 < 3 (unhinted: u0 < 3). For more information, run with TORCH_LOGS="+dynamic".\n\nFor more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example', user_stack=[<FrameSummary file /data/users/ezyang/b/pytorch/b.py, line 16 in f>, <FrameSummary file /data/users/ezyang/b/pytorch/b.py, line 8 in g>], graph_break=True)
```
An alternative to doing it this way, is I can make symbolic shapes print a warning log when guard on unbacked SymInt itself, so we don't have to worry about Dynamo generating the backtrace well. If, for the most part, the backtrace for other graph breaks is irrelevant, then this would seem to be a more expedient solution.
PTAL and submit your opinions.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119292
Approved by: https://github.com/yanboliang