[torch.fx] add code-gen customizability and support for setting breakpoint in code-gen'd forward() call (#67139)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67139
This diff enables setting breakpoint in the graph module's generated python code. See test plan for usage.
In order to support this functionality, and other similar functionalities to customize the generated code, a code transformer functionality is added to `fx.Graph`. This allows flexible customization of `fx.Graph`'s code gen behavior, in composable and functional ways. See test plan for its usage.
Test Plan:
### Use of `fx.experimental.debug.set_trace`
```
In [2]: from torch.fx.experimental.debug import set_trace
In [3]: set_trace(ttop)
Out[3]:
top(
(a): Sub()
)
In [4]: ttop(1)
> /data/users/kefeilu/fbsource33/fbcode/buck-out/dev/gen/caffe2/torch/fb/fx2trt/<eval_with_key>.10(6)forward()
(Pdb) l
1
2
3
4 def forward(self, x):
5 import pdb; pdb.set_trace()
6 -> a = self.a(x); x = None
7 getitem = a[0]
8 getitem_1 = a[0]; a = None
9 add = getitem + getitem_1; getitem = getitem_1 = None
10 return add
11
(Pdb)
```
### Use of `on_generate_code`
```
In [1]: def insert_pdb(body):
...: return ['import pdb; pdb.set_trace()\n', *body]
...:
In [8]: type(ttop)
Out[8]: torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl
In [10]: with ttop.graph.on_generate_code(lambda _: insert_pdb):
...: ttop.recompile()
...: print(f"== _on_generate_code should not be None: { ttop.graph._on_generate_code }")
...: print(ttop.code)
...:
== _on_generate_code should not be None: <function insert_pdb at 0x7fc9895ddd30>
def forward(self, x):
import pdb; pdb.set_trace()
a = self.a(x); x = None
getitem = a[0]
getitem_1 = a[0]; a = None
add = getitem + getitem_1; getitem = getitem_1 = None
return add
In [11]: ttop.graph._on_generate_code # restored to None
In [12]: ttop(1) # this should drop into pdb
> /data/users/kefeilu/fbsource33/fbcode/buck-out/dev/gen/caffe2/torch/fb/fx2trt/<eval_with_key>.6(6)forward()
(Pdb) l
1
2
3
4 def forward(self, x):
5 import pdb; pdb.set_trace()
6 -> a = self.a(x); x = None
7 getitem = a[0]
8 getitem_1 = a[0]; a = None
9 add = getitem + getitem_1; getitem = getitem_1 = None
10 return add
11
```
Reviewed By: jamesr66a
Differential Revision: D30736160
fbshipit-source-id: 9646867aae0461b5131dfd4ba9ee77a8c2ea9c93