[dynamo] fine-grained bytecode-source attribution in python 3.11 (#104676)
Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.
Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond
def h(x):
return x * 5
def true_fn(x):
return x * 2
def false_fn(x):
return x * 3
def f(pred, x):
x = h(
h(h(x))
)
x = x[1:][:2]
torch._dynamo.graph_break()
x = cond(pred, true_fn, false_fn, [x])
opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```
Output:
```
$ TORCH_LOGS="trace_call" python playground9.py
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^
TRACE FX call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^^^^
TRACE FX call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
x = h(
~^
h(h(x))
^^^^^^^
)
^
TRACE FX call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE FX call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~^^^^
TRACE FX call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
return x * 2
~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
return x * 3
~~^~~
TRACE FX call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104676
Approved by: https://github.com/ezyang