pytorch
69dcbc02 - [Dynamo]Expose bytecode hooks and add example usage for decompilation in docs (#110714)

Commit
1 year ago
[Dynamo]Expose bytecode hooks and add example usage for decompilation in docs (#110714) Dynamo dynamically translate bytecode of python functions, which is powerful but with difficult-to-understand bytecode. Most users cannot understand python bytecode. Although a general purpose way to decompile python bytecode into source code is very difficult, I find that this work can be greatly simplified since Dynamo already cleans up the code: the bytecode generated by Dynamo is a reduced subset of well-behaved python bytecode. I created a tiny decompiler for pytorch 2.0, named `depyf`: https://github.com/youkaichao/depyf . There are several takeaways: - **It supports pyton 3.7 - 3.11 (both inclusive), the same python versions supported by pytorch.** Since the main usage of this library is to understand pytorch 2.0, I plan to keep pace with pytorch. If pytorch supports a new python version, I can add support for that. (Actually, the core code is just about 1k lines. Adding support for new versions of python bytecode can be done in just several days.) - **I have tested the correctness of decompiled source code in torchbench.** I capture the modified bytecode generated by Dynamo, decompile it into source code, and then compile it into new bytecode, replace the Dynamo generated bytecode with new bytecode. And **it passed all the accuracy tests for timm models**. For huggingface models, the situation is more complicated: all failed cases are caused by the compile step: some functions use the `__class__` as closure variables, but decompiler can only get the code object, so it has no way to figure out the `__class__` , leading to a name error when compiling the decompiled code. That said, it passed the rest tests without the `__class__` issue. Please see the log file https://cloud.tsinghua.edu.cn/f/685e4af8d930499baa7c/?dl=1 and https://cloud.tsinghua.edu.cn/f/cab89500e15e4b62890b/?dl=1 for details. With the above efforts, I think it would be great to add an additional logging option in Dynamo: we can try to decompile the generated bytecode into source code, so that users can have a rough idea of what the modified bytecode does. It does not affect the workflow of Dynamo, but just adds more debug information. An example code from the [doc](https://pytorch.org/docs/main/torch.compiler_deepdive.html): ```python from typing import List import torch from torch import _dynamo as torchdynamo def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() return gm.forward # return a python callable @torchdynamo.optimize(my_compiler) def toy_example(a, b): x = a / (torch.abs(a) + 1) if b.sum() < 0: b = b * -1 return x * b for _ in range(100): toy_example(torch.randn(10), torch.randn(10)) ``` Run with `export TORCH_LOGS="+dynamo,guards,bytecode"`. Bytecode logging: ``` [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] ORIGINAL BYTECODE toy_example /Users/youkaichao/DeepLearning/depyf/ykc_test.py line 8 [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 10 0 LOAD_FAST 0 (a) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 2 LOAD_GLOBAL 0 (torch) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 4 LOAD_METHOD 1 (abs) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 6 LOAD_FAST 0 (a) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 8 CALL_METHOD 1 [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 10 LOAD_CONST 1 (1) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 BINARY_ADD [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 14 BINARY_TRUE_DIVIDE [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 16 STORE_FAST 2 (x) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 11 18 LOAD_FAST 1 (b) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 20 LOAD_METHOD 2 (sum) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 22 CALL_METHOD 0 [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 24 STORE_FAST 3 (__temp_2) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 26 LOAD_FAST 3 (__temp_2) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 28 LOAD_CONST 2 (0) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 30 COMPARE_OP 0 (<) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 32 POP_JUMP_IF_FALSE 21 (to 42) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 13 34 LOAD_FAST 1 (b) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 36 LOAD_CONST 3 (-1) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 38 BINARY_MULTIPLY [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 40 STORE_FAST 1 (b) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 14 >> 42 LOAD_FAST 2 (x) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 44 LOAD_FAST 1 (b) [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 46 BINARY_MULTIPLY [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 48 RETURN_VALUE [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 23:56:44,929] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] MODIFIED BYTECODE toy_example /Users/youkaichao/DeepLearning/depyf/ykc_test.py line 8 [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 8 0 LOAD_GLOBAL 3 (__compiled_fn_0) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 2 LOAD_FAST 0 (a) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 4 LOAD_FAST 1 (b) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 6 CALL_FUNCTION 2 [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 8 UNPACK_SEQUENCE 2 [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 10 STORE_FAST 2 (x) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 POP_JUMP_IF_FALSE 12 (to 24) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 14 LOAD_GLOBAL 4 (__resume_at_34_1) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 16 LOAD_FAST 1 (b) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 18 LOAD_FAST 2 (x) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 20 CALL_FUNCTION 2 [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 22 RETURN_VALUE [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] >> 24 LOAD_GLOBAL 5 (__resume_at_42_2) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 26 LOAD_FAST 1 (b) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 28 LOAD_FAST 2 (x) [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 30 CALL_FUNCTION 2 [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 32 RETURN_VALUE [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 23:56:44,930] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] ``` New output with this PR: ``` [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] possible source code: [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] def toy_example(a, b): [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] __temp_1 = __compiled_fn_0(a, b) [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] x = __temp_1[0] [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] if __temp_1[1]: [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] return __resume_at_34_1(b, x) [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] return __resume_at_42_2(b, x) [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,535] [0/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues. ``` The rest two log (please pay attention to the output `possible source code:`): ``` [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] ORIGINAL BYTECODE <resume in toy_example> /workspace/youkaichao/code/pytorch/ykc.py line 12 [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 0 JUMP_ABSOLUTE 22 (to 44) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 2 LOAD_FAST 2 (a) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 4 LOAD_GLOBAL 0 (torch) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 6 LOAD_ATTR 1 (abs) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 8 LOAD_FAST 2 (a) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 10 CALL_FUNCTION 1 [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 LOAD_CONST 1 (1) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 14 BINARY_ADD [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 16 BINARY_TRUE_DIVIDE [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 18 STORE_FAST 1 (x) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 20 LOAD_FAST 0 (b) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 22 LOAD_ATTR 2 (sum) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 24 CALL_FUNCTION 0 [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 26 STORE_FAST 3 (__temp_2) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 28 LOAD_FAST 3 (__temp_2) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 30 LOAD_CONST 2 (0) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 32 COMPARE_OP 0 (<) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 34 POP_JUMP_IF_FALSE 22 (to 44) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 36 LOAD_FAST 0 (b) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 38 LOAD_CONST 3 (-1) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 40 BINARY_MULTIPLY [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 42 STORE_FAST 0 (b) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 14 >> 44 LOAD_FAST 1 (x) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 46 LOAD_FAST 0 (b) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 48 BINARY_MULTIPLY [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 50 RETURN_VALUE [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] MODIFIED BYTECODE <resume in toy_example> /workspace/youkaichao/code/pytorch/ykc.py line 12 [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 0 LOAD_GLOBAL 3 (__compiled_fn_3) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 2 LOAD_FAST 0 (b) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 4 LOAD_FAST 1 (x) [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 6 CALL_FUNCTION 2 [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 8 UNPACK_SEQUENCE 1 [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 10 RETURN_VALUE [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,566] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,567] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] possible source code: [2023-10-06 16:25:21,567] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] def <resume in toy_example>(b, x): [2023-10-06 16:25:21,567] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] return __compiled_fn_3(b, x)[0] [2023-10-06 16:25:21,567] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,567] [1/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues. ``` ``` [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] ORIGINAL BYTECODE <resume in toy_example> /workspace/youkaichao/code/pytorch/ykc.py line 12 [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 0 JUMP_ABSOLUTE 18 (to 36) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 2 LOAD_FAST 2 (a) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 4 LOAD_GLOBAL 0 (torch) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 6 LOAD_ATTR 1 (abs) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 8 LOAD_FAST 2 (a) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 10 CALL_FUNCTION 1 [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 LOAD_CONST 1 (1) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 14 BINARY_ADD [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 16 BINARY_TRUE_DIVIDE [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 18 STORE_FAST 1 (x) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 20 LOAD_FAST 0 (b) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 22 LOAD_ATTR 2 (sum) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 24 CALL_FUNCTION 0 [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 26 STORE_FAST 3 (__temp_2) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 28 LOAD_FAST 3 (__temp_2) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 30 LOAD_CONST 2 (0) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 32 COMPARE_OP 0 (<) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 34 POP_JUMP_IF_FALSE 22 (to 44) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 13 >> 36 LOAD_FAST 0 (b) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 38 LOAD_CONST 3 (-1) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 40 BINARY_MULTIPLY [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 42 STORE_FAST 0 (b) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 14 >> 44 LOAD_FAST 1 (x) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 46 LOAD_FAST 0 (b) [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 48 BINARY_MULTIPLY [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 50 RETURN_VALUE [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,579] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] MODIFIED BYTECODE <resume in toy_example> /workspace/youkaichao/code/pytorch/ykc.py line 12 [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 12 0 LOAD_GLOBAL 3 (__compiled_fn_4) [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 2 LOAD_FAST 0 (b) [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 4 LOAD_FAST 1 (x) [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 6 CALL_FUNCTION 2 [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 8 UNPACK_SEQUENCE 1 [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] 10 RETURN_VALUE [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] possible source code: [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] def <resume in toy_example>(b, x): [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] return __compiled_fn_4(b, x)[0] [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] [2023-10-06 16:25:21,580] [2/0] torch._dynamo.convert_frame.__bytecode: [DEBUG] If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/110714 Approved by: https://github.com/jansel
Author
Committer
Parents
Loading