[fx] Add constant folding pass (#48443)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48443
Add a constant folding pass in FX:
- Iterate over an input graph and tag what nodes are fully constant, i.e. either `get_attr` nodes, or nodes with all inputs that are either `get_attr` or constant
- Use `model_transform.split_by_tags()` to split the graph into two
- Look for the `output` node in the constant graph to get names of attrs that will be folded
- Iterate over the non-constant graph and replace placeholders that are using the same name as the attrs with a `get_attr` as well as a dummy attr on the module
- Return these two graphs in a new `FoldedGraphModule`, which is a normal GraphModule but also stores the constant graph on the side along with a `run_folding()` method that will run const folding and update the dummy parameters with the actual folded parameters
Test Plan: Added a couple tests
Reviewed By: 842974287
Differential Revision: D25033996
fbshipit-source-id: 589c036751ea91bb8155d9be98af7dbc0552ea19