pytorch
95fced44 - Pretty dataclass dynamo explain (#102869)

Commit
1 year ago
Pretty dataclass dynamo explain (#102869) Also thinking out loud: maybe we only print graph break reasons? And for the rest we have a verbose print which prints everything? TODO: some tests are failing based on what they expect a guard string to look like, easy to fix i'll do it early next week # After ``` (sourcetorch) ubuntu@ip-172-31-1-136:~/test$ python pretty.py BREAK Graph Count: 2 Graph Break Count: 1 Op Count: 2 Break Reasons: Break Reason 1: Reason: call_function BuiltinVariable(print) [ConstantVariable(str)] {} User Stack: <FrameSummary file /home/ubuntu/test/pretty.py, line 6 in fn> Ops per Graph: Ops 1: <built-in function add> Ops 2: <built-in function add> Out Guards: Guard 1: Name: '' Source: global Create Function: GRAD_MODE Guard Types: ['GRAD_MODE'] Code List: ['___is_grad_enabled()'] Object Weakref: None Guarded Class Weakref: None Guard 2: Name: '' Source: global Create Function: DEFAULT_DEVICE Guard Types: ['DEFAULT_DEVICE'] Code List: ['utils_device.CURRENT_DEVICE == None'] Object Weakref: None Guarded Class Weakref: None Guard 3: Name: "G['print']" Source: global Create Function: BUILTIN_MATCH Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 4: Name: '' Source: global Create Function: DETERMINISTIC_ALGORITHMS Guard Types: ['DETERMINISTIC_ALGORITHMS'] Code List: ['not ___are_deterministic_algorithms_enabled()'] Object Weakref: None Guarded Class Weakref: None Guard 5: Name: "L['x']" Source: local Create Function: TENSOR_MATCH Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 6: Name: '' Source: global Create Function: GRAD_MODE Guard Types: ['GRAD_MODE'] Code List: ['___is_grad_enabled()'] Object Weakref: None Guarded Class Weakref: None Guard 7: Name: '' Source: global Create Function: DEFAULT_DEVICE Guard Types: ['DEFAULT_DEVICE'] Code List: ['utils_device.CURRENT_DEVICE == None'] Object Weakref: None Guarded Class Weakref: None Guard 8: Name: '' Source: global Create Function: DETERMINISTIC_ALGORITHMS Guard Types: ['DETERMINISTIC_ALGORITHMS'] Code List: ['not ___are_deterministic_algorithms_enabled()'] Object Weakref: None Guarded Class Weakref: None Guard 9: Name: "L['x']" Source: local Create Function: TENSOR_MATCH Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Compile Times: TorchDynamo compilation metrics: Function Runtimes (s) ------------------------------ -------------- _compile 0.0164, 0.0035 OutputGraph.call_user_compiler 0.0000, 0.0000 ``` ## Before ``` ('Dynamo produced 2 graphs with 1 graph break and 2 ops', [{Guard(name='print', source=<GuardSource.GLOBAL: 1>, create_fn=<function GuardBuilder.BUILTIN_MATCH at 0x7f92ea5009d0>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None), Guard(name='x', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7f92ea501000>, is_volatile=False, guard_types=['TENSOR_MATCH'], code_list=None, obj_weakref=<weakref at 0x7f9224d28f40; dead>, guarded_class_weakref=<weakref at 0x7f92d81734c0; to 'torch._C._TensorMeta' at 0x540b610 (Tensor)>)}, {Guard(name='x', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7f92ea501000>, is_volatile=False, guard_types=['TENSOR_MATCH'], code_list=None, obj_weakref=<weakref at 0x7f9224d5e700; dead>, guarded_class_weakref=<weakref at 0x7f92d81734c0; to 'torch._C._TensorMeta' at 0x540b610 (Tensor)>)}], [GraphModule(), GraphModule()], [[<built-in function add>], [<built-in function add>]], [GraphCompileReason(reason='call_function BuiltinVariable(print) [ConstantVariable(str)] {}', user_stack=[<FrameSummary file <ipython-input-1-9e2ddb639697>, line 6 in fn>]), GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file <ipython-input-1-9e2ddb639697>, line 8 in <graph break in fn>>])], 'Dynamo produced 2 graphs with 1 graph break and 2 ops\n Break reasons: \n\n1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}\n File "<ipython-input-1-9e2ddb639697>", line 6, in fn\n print("BREAK")\n \n2. return_value\n File "<ipython-input-1-9e2ddb639697>", line 8, in <graph break in fn>\n return x\n \nTorchDynamo compilation metrics:\nFunction Runtimes (s)\n------------------------------ --------------\n_compile 0.0418, 0.0084\nOutputGraph.call_user_compiler 0.0001, 0.0001') ``` ## Program ```python import torch import torch._dynamo def fn(x): x = x + 1 print("BREAK") x = x + 1 return x out = torch._dynamo.explain(fn, torch.randn(10)) print(out) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/102869 Approved by: https://github.com/voznesenskym
Author
Committer
Parents
Loading