d635d0f8 - Refactor FX codegen into extensible Codegen object (#72566)

Refactor FX codegen into extensible Codegen object (#72566)

Summary: The goal of this is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points on it.

```python
class Codegen(object):
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the
        beginning of the FX function. By default,
        `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
        """

    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX
        function.
        """

    def process_inputs(self, args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true:
        `f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
        """

    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.
        See ``process_inputs`` for more details.
        """

    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add them here. For example,
        return [('List', typing.List)] if you need ``List`` in the global
        context.
        """
```

So, for example, the `ListCodeGen` we want for AOTAutograd looks like this:

```python
class ListCodeGen(CodeGen):
    def generate_prologue(self, free_vars, maybe_return_annotation):
        lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
        return lst_unpack

    def additional_globals(self):
        return [('List', typing.List)]

    def process_inputs(self, *inputs):
        assert len(inputs) == 1
        return inputs[0]
```

and

```python
def f(a, b):
    return a + b

nf = fx.symbolic_trace(f)
nf.graph.set_codegen(ListCodeGen())
nf.recompile()
print(nf.code)
```

would result in

```python
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list
    add = a + b;  a = b = None
    return add
```

Backwards-compatibility changes: I added `process_outputs` and `process_inputs` to `fx.Graph`, while removing `flatten_inputs` and `flatten_outputs` - those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566
Reviewed By: desertfire
Differential Revision: D34160424
Pulled By: Chillee
fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa11cfb8189b114b4ee2de89257bd704a)
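To illustrate the mechanism the commit describes, here is a minimal, torch-free sketch of the pattern: a `Codegen` base class with the same hook names, a list-unpacking subclass, and a hypothetical `compile_fn` driver (not part of torch.fx) that stitches the prologue, body, and return statement together and `exec`s the result. The class and function names beyond those in the commit are illustrative assumptions, not the actual torch.fx implementation.

```python
from typing import Any, List, Tuple

class Codegen:
    # Default hooks, mirroring the five extensibility points named in the commit.
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        # Default behavior: generate_prologue(['a', 'b'], '') == 'def forward(a, b):'
        return f"def forward({', '.join(free_vars)}){maybe_return_annotation}:"

    def generate_output(self, output_arg: str) -> str:
        return f"    return {output_arg}"

    def process_inputs(self, *args: Any) -> Any:
        return args

    def process_outputs(self, outputs: Any) -> Any:
        return outputs

    def additional_globals(self) -> List[Tuple[str, Any]]:
        return []


class ListCodegen(Codegen):
    # Variant from the commit: the generated function takes one list argument
    # and unpacks it into the free variables.
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        return (f"def forward(args_list){maybe_return_annotation}:\n"
                f"    {', '.join(free_vars)} = args_list")

    def process_inputs(self, *inputs: Any) -> Any:
        assert len(inputs) == 1
        return inputs[0]


def compile_fn(codegen: Codegen, free_vars: List[str], body: List[str], output: str):
    # Hypothetical driver: assemble source via the codegen hooks and exec it.
    src = "\n".join([codegen.generate_prologue(free_vars, ""),
                     *("    " + line for line in body),
                     codegen.generate_output(output)])
    env = dict(codegen.additional_globals())
    exec(src, env)
    return env["forward"], src


fn, src = compile_fn(ListCodegen(), ["a", "b"], ["add = a + b"], "add")
print(src)          # generated list-unpacking forward()
print(fn([1, 2]))   # prints 3
```

Swapping `ListCodegen()` for the base `Codegen()` yields a plain `def forward(a, b):` signature instead, which is the same division of labor the commit's `set_codegen`/`recompile` flow exposes on `fx.Graph`.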