5c38c4cf - Improve symbolic shapes guard logging (#98941)

Improve symbolic shapes guard logging (#98941)

Billing of changes:

* Get rid of `print_guards`; instead, you control this with `TORCH_LOGS=torch.fx.experimental.symbolic_shapes`. Debug-level logging additionally toggles stack traces.
* Don't incorrectly report the tracing context frame when we're compiling; we just don't have this info anymore! (TODO: use the saved frames instead.) This is done via a new TracingContext.clear_frame context manager.
* Add TracingContext.extract_stack(), which gives you the tracing context stack.
* Add ShapeEnvLoggingAdapter to report which ShapeEnv any given operation comes from (this is helpful for debugging situations where there are too many ShapeEnvs floating around).
* Tweak the create_symbol log message to also report the Source.
* Add a debug log whenever duck sizing occurs.
* Report an excerpt of both the user and system backtrace whenever a guard is added in INFO mode. I found this is a good balance of "where did the guard come from" without full backtrace verbosity.

Example of the new log output:

```
[2023-04-12 08:25:49,003] torch.fx.experimental.symbolic_shapes: [INFO] 0: create_env
[2023-04-12 08:25:49,021] torch.fx.experimental.symbolic_shapes: [INFO] 0: create_symbol s0 = 32 for L['x'].size()[0]
[2023-04-12 08:25:50,154] torch.fx.experimental.symbolic_shapes: [INFO] 0: evaluate_expr s0 < 128 [guard added] at w.py:11 in forward2 (_dynamo/variables/tensor.py:476 in evaluate_expr)
[2023-04-12 08:25:52,057] torch.fx.experimental.symbolic_shapes: [INFO] 0: evaluate_expr Eq(Mod(s0, 16), 0) [guard added] (_inductor/codegen/triton.py:77 in is_aligned)
```

from running

```
import torch
import torch._dynamo

def f(x, y):
    return x + y

def forward(x, y):
    return forward2(x, y)

def forward2(x, y):
    if x.size(0) < 128:
        x = x * 2
    else:
        x = x * 3
    r = f(x, y)
    r = r * y
    return r

def woof():
    fn_compiled = torch.compile(forward, dynamic=True)
    x = torch.randn(32, device='cuda')
    y = torch.randn(32, device='cuda')
    print(fn_compiled(x, y))

woof()
```

(To induce the Triton guard, I synthetically reverted https://github.com/pytorch/pytorch/pull/98471)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98941
Approved by: https://github.com/wconstab
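
---

For reference, a minimal sketch of how the logging described above could be enabled when running the repro script (saved as `w.py`, the filename shown in the sample log). The shell equivalent would be `TORCH_LOGS=torch.fx.experimental.symbolic_shapes python w.py`; the Python variant below assumes `TORCH_LOGS` is parsed when `torch` is imported, as the commit message implies.

```
import os

# Enable the symbolic shapes guard logger named in this commit (INFO level).
# Per the notes above, debug-level logging would additionally print stack
# traces for every guard that is added.
os.environ["TORCH_LOGS"] = "torch.fx.experimental.symbolic_shapes"

import torch  # set TORCH_LOGS before this import (assumption: parsed at import time)

# ...then define forward/forward2/woof as in the repro above and call woof();
# the guard messages shown in the sample log should appear on stderr.
```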