pytorch
71beca48 - [dynamo, logging] Report name of defining class along side function name in Dynamo logs (#110190)

Commit
1 year ago
[dynamo, logging] Report name of defining class along side function name in Dynamo logs (#110190) Implement https://github.com/pytorch/pytorch/issues/109236 Sample code: ```python import torch class AAA: class DUMMY: class DUMMY2: pass def dummy(self): def dummy2(): pass class BBB: @staticmethod def CCC(): class DDD: if True: @staticmethod def EEE(): x = [torch.ones(3, 3) for _ in range(5)] return x return DDD def fn(): return AAA.BBB.CCC().EEE() opt_fn = torch.compile(fn, backend="eager") opt_fn() ``` Logs: ```bash $TORCH_LOGS="trace_source" python playground2.py [2023-09-27 17:38:35,641] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:21 in fn (fn) [2023-09-27 17:38:35,641] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] def fn(): [2023-09-27 17:38:35,642] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:22 in fn (fn) [2023-09-27 17:38:35,642] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] return AAA.BBB.CCC().EEE() [2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:11 in CCC (AAA.BBB) (inline depth: 1) [2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] @staticmethod [2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:13 in CCC (AAA.BBB.CCC.DDD) (inline depth: 1) [2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] class DDD: [2023-09-27 17:38:35,723] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:17 in <listcomp> (AAA.BBB.CCC.DDD.EEE) [2023-09-27 17:38:35,723] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] x = [torch.ones(3, 3) for _ in range(5)] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/110190 Approved by: https://github.com/ezyang, https://github.com/mlazos
Author
Committer
Parents
Loading