[JIT] Add partial evaluation graph stitching logic (#65377)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65377
When we run symbolic shape analysis on
```
conv = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
mod = nn.Sequential(conv1, max_pool)
...
graph(%self : __torch__.torch.nn.modules.container.___torch_mangle_0.Sequential,
%input.1 : Tensor):
%18 : bool = prim::Constant[value=0]()
%30 : int[] = prim::Constant[value=[1, 1]]()
%29 : int[] = prim::Constant[value=[3, 3]]()
%28 : int[] = prim::Constant[value=[2, 2]]()
%6 : int = prim::Constant[value=1]()
%self.0.bias : NoneType = prim::Constant()
%self.0.weight : Double(64, 3, 7, 7, strides=[147, 49, 7, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%input.5 : Tensor(SS(-2), 64, SS(-3), SS(-4)) = aten::conv2d(%input.1, %self.0.weight, %self.0.bias, %28, %29, %30, %6)
%input.9 : Tensor(SS(-2), 64, SS(-5), SS(-6)) = aten::max_pool2d(%input.5, %29, %28, %30, %30, %18)
return (%input.9)
```
we partially evaluate the shape compute graph of `conv2d`, whose output gets passed in and used to partially evaluate the shape compute graph of `max_pool2d`.
The conv2d remaining partially eval'd graph is [here](https://gist.github.com/eellison/0598bd224a422211efa1a45d2b7560b7), and the maxpool2d eval'd graph is [here](https://gist.github.com/eellison/625540b84f650ddbefd3ae5511ab8814). We can take the partially eval'd graphs of a series of operators and stitch them together, which allows us to
a) recover symbolic equivalences by CSE'ing & other optimizations
b) calculate shapes for a whole block of operators just on the input, such as for fusing the whole model to nnc with dynamic shapes and then passing along the computed symbolic shapes. the calculation will also handle error handling.
c) (future-looking) generate inputs on demand for straight-line networks that are composed just of aten operators
The combined graph of the two gives us compute for the unknown symbolic dimensions - `SS(-2), SS(-3), SS(-4), SS(-5), and SS(-6)`.
```
graph(%input.1 : int[]):
%42 : bool = prim::Constant[value=0]() # <string>:152:17
%15 : int = prim::Constant[value=3]()
%input_batch_size_dim.1 : int = prim::Constant[value=0]() # <string>:417:41
%13 : int = prim::Constant[value=1]() # <string>:426:61
%12 : int = prim::Constant[value=4]() # <string>:437:32
%11 : str = prim::Constant[value="AssertionError: "]()
%9 : int = prim::Constant[value=2]()
%8 : int = prim::Constant[value=6]()
%7 : int = prim::Constant[value=7]()
%16 : int = aten::len(%input.1) # <string>:438:17
%17 : bool = aten::eq(%16, %12) # <string>:438:17
= prim::If(%17) # <string>:438:10
block0():
-> ()
block1():
= prim::RaiseException(%11) # <string>:438:10
-> ()
%18 : int = aten::__getitem__(%input.1, %13) # <string>:407:17
%19 : bool = aten::eq(%18, %15) # <string>:407:17
= prim::If(%19) # <string>:407:10
block0():
-> ()
block1():
= prim::RaiseException(%11) # <string>:407:10
-> ()
%20 : int = aten::__getitem__(%input.1, %9) # <string>:411:20
%21 : int = aten::add(%20, %8) # <string>:411:20
%22 : bool = aten::ge(%21, %7) # <string>:411:20
= prim::If(%22) # <string>:411:12
block0():
-> ()
block1():
= prim::RaiseException(%11) # <string>:411:12
-> ()
%23 : int = aten::__getitem__(%input.1, %15) # <string>:411:20
%24 : int = aten::add(%23, %8) # <string>:411:20
%25 : bool = aten::ge(%24, %7) # <string>:411:20
= prim::If(%25) # <string>:411:12
block0():
-> ()
block1():
= prim::RaiseException(%11) # <string>:411:12
-> ()
%26 : int = aten::__getitem__(%input.1, %input_batch_size_dim.1) # <string>:422:29
%27 : int = aten::sub(%20, %13) # <string>:428:32
%28 : int = aten::floordiv(%27, %9) # <string>:428:32
%29 : int = aten::add(%28, %13) # <string>:428:32
%30 : int = aten::sub(%23, %13) # <string>:428:32
%31 : int = aten::floordiv(%30, %9) # <string>:428:32
%32 : int = aten::add(%31, %13) # <string>:428:32
%48 : int = aten::floordiv(%28, %9) # <string>:133:17
%outputSize.2 : int = aten::add(%48, %13) # <string>:136:23
%51 : int = aten::floordiv(%31, %9) # <string>:133:17
%outputSize.1 : int = aten::add(%51, %13) # <string>:136:23
%53 : bool = aten::ne(%29, %input_batch_size_dim.1) # <string>:156:41
%54 : bool = prim::If(%53) # <string>:157:64
block0():
%55 : bool = aten::ne(%32, %input_batch_size_dim.1) # <string>:157:93
-> (%55)
block1():
-> (%42)
= prim::If(%54) # <string>:157:10
block0():
-> ()
block1():
= prim::RaiseException(%11) # <string>:157:10
-> ()
%56 : bool = aten::ge(%outputSize.1, %13) # <string>:160:17
%57 : bool = prim::If(%56) # <string>:160:17
block0():
%58 : bool = aten::ge(%outputSize.2, %13) # <string>:160:38
-> (%58)
block1():
-> (%42)
= prim::If(%57) # <string>:160:10
block0():
-> ()
block1():
= prim::RaiseException(%11) # <string>:160:10
-> ()
return (%26, %29, %32, %outputSize.2, %outputSize.1)
```
This PR runs shape analysis, retains the partially evaluated graphs, and then stitches them together, keeping track of what inputs in the partial eval graph correspond to what inputs in the encompassing graph IR and what outputs correspond to what symbolic shape. Adding NNC ppl as reviewers because it is relevant to dynamic shape fusion.
Question for reviewers : should I make this a separate file ?
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D31732419
Pulled By: eellison
fbshipit-source-id: 883a55cbeef0fd5a6068a779ffa89b6f537245b3