Refactor FX codegen into extensible Codegen object (#72566)
Summary:
The goal of this is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points on it.
```
class Codegen(object):
def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
"""
Given the free variables and a return annotation, generates the beginning of the FX function.
By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
"""
def generate_output(self, output_args: Argument) -> str:
"""
Given the output arguments, generates the return statement of the FX function.
"""
def process_inputs(self, args: Any) -> Any:
"""
Transforms the inputs so that the graph can take them as arguments, as
non-default codegen may result in the inputs to the function being
different from the inputs to the graph.
If the graph was directly runnable, this invariant should hold true
`f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
"""
def process_outputs(self, outputs: Any) -> Any:
"""
Transforms the outputs of the graph to be identical to the codegen.
See ``process_inputs`` for more details.
"""
def additional_globals(self) -> List[Tuple[str, Any]]:
"""
If your codegen uses extra global values, add them here.
For example, return ['List', typing.List] if you need ``List`` in the global context.
"""
```
So, for example, the `ListCodeGen` we want for AOTAutograd looks like this
```
class ListCodeGen(CodeGen):
def generate_prologue(self, free_vars, maybe_return_annotation):
lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
{', '.join(free_vars)} = args_list"""
return lst_unpack
def additional_globals(self):
return [('List', typing.List)]
def process_inputs(self, *inputs):
assert(len(inputs) == 1)
return inputs[0]
```
and
```
def f(a, b):
return a + b
nf = fx.symbolic_trace(f)
nf.graph.set_codegen(ListCodeGen())
nf.recompile()
print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
a, b = args_list
add = a + b; a = b = None
return add
```
Backwards compatibility changes - I added `process_outputs` and `process_inputs` to `fx.Graph`, while removing `flatten_inputs` and `flatten_outputs` - those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566
Reviewed By: desertfire
Differential Revision: D34160424
Pulled By: Chillee
fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa11cfb8189b114b4ee2de89257bd704a)