[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion pass that lets us use static runtime for *parts* of a graph. This is useful when fully eliminating control flow is hard, since the regions around the control flow can still be run by static runtime.
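To illustrate the idea of fusing only parts of a graph, here is a toy sketch in plain Python. It is not the pass added in this diff: it works on a flat list of op names rather than a TorchScript graph, and the helper name `fuse_supported_runs` and the `supported` set are hypothetical, chosen only for illustration.

```python
from typing import List, Set, Union


def fuse_supported_runs(ops: List[str], supported: Set[str]) -> List[Union[str, List[str]]]:
    """Group maximal runs of consecutive supported ops into sublists.

    Hypothetical helper: unsupported ops (e.g. control flow) are kept
    as-is and act as boundaries between fused groups.
    """
    result: List[Union[str, List[str]]] = []
    run: List[str] = []
    for op in ops:
        if op in supported:
            # Extend the current fusable run.
            run.append(op)
        else:
            # An unsupported op ends the current run, if there is one.
            if run:
                result.append(run)
                run = []
            result.append(op)
    if run:
        result.append(run)
    return result


ops = ["aten::mul", "aten::add", "prim::Loop", "aten::add"]
print(fuse_supported_runs(ops, {"aten::mul", "aten::add"}))
```

Each sublist plays the role of one fused subgraph, and `prim::Loop` stays outside of them. A real pass also has to respect data dependencies and aliasing, which this sketch ignores.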
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add test case for a graph that isn't fully lowered
[x] add test case for a graph with unusual list/tuple outputs
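The loop example looks quite good. The ops before the loop and the ops in the loop body are each fused into a prim::StaticSubgraph node, while prim::Loop stays in the outer graph: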
```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %iters.1 : int):
  %12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
  %c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
  %c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
    block0(%i : int, %c.12 : Tensor):
      %c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
      -> (%12, %c.10)
  return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
      %4 : Tensor):
  %5 : int = prim::Constant[value=2]()
  %6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
  %2 : int = prim::Constant[value=1]()
  %c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
  return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
      %7 : Tensor,
      %8 : Tensor):
  %9 : int = prim::Constant[value=1]()
  %c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
  %5 : int = prim::Constant[value=2]()
  %c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
  %2 : int = prim::Constant[value=1]()
  %c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
  return (%c.10)
```
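For reference, the graph above appears to correspond to a Python function along these lines. This is a reconstruction from the IR and its source-location comments (test_static_runtime.py lines 109 to 113), not a copy of the test; the function name `loop_example` is hypothetical. The sketch is plain Python so it runs on numbers without PyTorch, while the real test presumably scripts an equivalent function over tensors.

```python
def loop_example(a, b, iters):
    # Line 109: aten::mul then aten::add, fused into prim::StaticSubgraph_0.
    c = a + b * 2
    # Line 110: prim::Loop, which stays in the outer graph.
    for _ in range(iters):
        # Lines 111-113: the loop body, fused into prim::StaticSubgraph_1.
        c = c + b  # aten::add
        c *= 2     # aten::mul_ (in-place on tensors; rebinding on plain ints)
        c -= a     # aten::sub_ (in-place on tensors; rebinding on plain ints)
    return c


print(loop_example(1, 2, 0))  # 5
print(loop_example(1, 2, 1))  # 13
```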
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098