pytorch
5db7db66 - [JIT] Add partial evaluation graph stitching logic (#65377)

Commit
3 years ago
[JIT] Add partial evaluation graph stitching logic (#65377) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65377 When we run symbolic shape analysis on ``` conv = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) mod = nn.Sequential(conv1, max_pool) ... graph(%self : __torch__.torch.nn.modules.container.___torch_mangle_0.Sequential, %input.1 : Tensor): %18 : bool = prim::Constant[value=0]() %30 : int[] = prim::Constant[value=[1, 1]]() %29 : int[] = prim::Constant[value=[3, 3]]() %28 : int[] = prim::Constant[value=[2, 2]]() %6 : int = prim::Constant[value=1]() %self.0.bias : NoneType = prim::Constant() %self.0.weight : Double(64, 3, 7, 7, strides=[147, 49, 7, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]() %input.5 : Tensor(SS(-2), 64, SS(-3), SS(-4)) = aten::conv2d(%input.1, %self.0.weight, %self.0.bias, %28, %29, %30, %6) %input.9 : Tensor(SS(-2), 64, SS(-5), SS(-6)) = aten::max_pool2d(%input.5, %29, %28, %30, %30, %18) return (%input.9) ``` we partially evaluate the shape compute graph of `conv2d`, whose output gets passed in and used to partially evaluate the shape compute graph of `max_pool2d`. The conv2d remaining partially eval'd graph is [here](https://gist.github.com/eellison/0598bd224a422211efa1a45d2b7560b7), and the maxpool2d eval'd graph is [here](https://gist.github.com/eellison/625540b84f650ddbefd3ae5511ab8814). We can take the partially eval'd graphs of a series of operators and stitch them together, which allows us to a) recover symbolic equivalences by CSE'ing & other optimizations b) calculate shapes for a whole block of operators just on the input, such as for fusing the whole model to nnc with dynamic shapes and then passing along the computed symbolic shapes. the calculation will also handle error handling. c) (future-looking) generate inputs on demand for straight-line networks that are composed just of aten operators The combined graph of the two gives us compute for the unknown symbolic dimensions - `SS(-2), SS(-3), SS(-4), SS(-5), and SS(-6)`. ``` graph(%input.1 : int[]): %42 : bool = prim::Constant[value=0]() # <string>:152:17 %15 : int = prim::Constant[value=3]() %input_batch_size_dim.1 : int = prim::Constant[value=0]() # <string>:417:41 %13 : int = prim::Constant[value=1]() # <string>:426:61 %12 : int = prim::Constant[value=4]() # <string>:437:32 %11 : str = prim::Constant[value="AssertionError: "]() %9 : int = prim::Constant[value=2]() %8 : int = prim::Constant[value=6]() %7 : int = prim::Constant[value=7]() %16 : int = aten::len(%input.1) # <string>:438:17 %17 : bool = aten::eq(%16, %12) # <string>:438:17 = prim::If(%17) # <string>:438:10 block0(): -> () block1(): = prim::RaiseException(%11) # <string>:438:10 -> () %18 : int = aten::__getitem__(%input.1, %13) # <string>:407:17 %19 : bool = aten::eq(%18, %15) # <string>:407:17 = prim::If(%19) # <string>:407:10 block0(): -> () block1(): = prim::RaiseException(%11) # <string>:407:10 -> () %20 : int = aten::__getitem__(%input.1, %9) # <string>:411:20 %21 : int = aten::add(%20, %8) # <string>:411:20 %22 : bool = aten::ge(%21, %7) # <string>:411:20 = prim::If(%22) # <string>:411:12 block0(): -> () block1(): = prim::RaiseException(%11) # <string>:411:12 -> () %23 : int = aten::__getitem__(%input.1, %15) # <string>:411:20 %24 : int = aten::add(%23, %8) # <string>:411:20 %25 : bool = aten::ge(%24, %7) # <string>:411:20 = prim::If(%25) # <string>:411:12 block0(): -> () block1(): = prim::RaiseException(%11) # <string>:411:12 -> () %26 : int = aten::__getitem__(%input.1, %input_batch_size_dim.1) # <string>:422:29 %27 : int = aten::sub(%20, %13) # <string>:428:32 %28 : int = aten::floordiv(%27, %9) # <string>:428:32 %29 : int = aten::add(%28, %13) # <string>:428:32 %30 : int = aten::sub(%23, %13) # <string>:428:32 %31 : int = aten::floordiv(%30, %9) # <string>:428:32 %32 : int = aten::add(%31, %13) # <string>:428:32 %48 : int = aten::floordiv(%28, %9) # <string>:133:17 %outputSize.2 : int = aten::add(%48, %13) # <string>:136:23 %51 : int = aten::floordiv(%31, %9) # <string>:133:17 %outputSize.1 : int = aten::add(%51, %13) # <string>:136:23 %53 : bool = aten::ne(%29, %input_batch_size_dim.1) # <string>:156:41 %54 : bool = prim::If(%53) # <string>:157:64 block0(): %55 : bool = aten::ne(%32, %input_batch_size_dim.1) # <string>:157:93 -> (%55) block1(): -> (%42) = prim::If(%54) # <string>:157:10 block0(): -> () block1(): = prim::RaiseException(%11) # <string>:157:10 -> () %56 : bool = aten::ge(%outputSize.1, %13) # <string>:160:17 %57 : bool = prim::If(%56) # <string>:160:17 block0(): %58 : bool = aten::ge(%outputSize.2, %13) # <string>:160:38 -> (%58) block1(): -> (%42) = prim::If(%57) # <string>:160:10 block0(): -> () block1(): = prim::RaiseException(%11) # <string>:160:10 -> () return (%26, %29, %32, %outputSize.2, %outputSize.1) ``` This PR runs shape analysis, retains the partially evaluated graphs, and then stitches them together, keeping track of what inputs in the partial eval graph correspond to what inputs in the encompassing graph IR and what outputs correspond to what symbolic shape. Adding NNC ppl as reviewers because it is relevant to dynamic shape fusion. Question for reviewers : should I make this a separate file ? Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D31732419 Pulled By: eellison fbshipit-source-id: 883a55cbeef0fd5a6068a779ffa89b6f537245b3
Author
Elias Ellison
Parents
Loading