pytorch
38ed3985 - [fx] Add constant folding pass (#48443)

Commit
4 years ago
[fx] Add constant folding pass (#48443) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48443 Add a constant folding pass in FX: - Iterate over an input graph and tag what nodes are fully constant, i.e. either `get_attr` nodes, or nodes with all inputs that are either `get_attr` or constant - Use `model_transform.split_by_tags()` to split the graph into two - Look for the `output` node in the constant graph to get names of attrs that will be folded - Iterate over the non-constant graph and replace placeholders that are using the same name as the attrs with a `get_attr` as well as a dummy attr on the module - Return these two graphs in a new `FoldedGraphModule`, which is a normal GraphModule but also stores the constant graph on the side along with a `run_folding()` method that will run const folding and update the dummy parameters with the actual folded parameters Test Plan: Added a couple tests Reviewed By: 842974287 Differential Revision: D25033996 fbshipit-source-id: 589c036751ea91bb8155d9be98af7dbc0552ea19
Author
Parents
Loading