Add Initial NNC Dynamic Shapes Flow (#66136)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66136
FOR REVIEWERS: this is ready to review, test failures comes from somewhere else in stack..
Takes in a TensorExprGraph of static shapes and generalizes the input shapes
to symbolic dimensions. Dimensions of value 1 will be preserved, otherwise
dimensions with the same value will be bucketed to the same symbolic shape.
E.g. `Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)`
From there, runs symbolic shape inference on the graph, and creates a
versioning if in the graph with prim::TensorExprDynamicGuard checking if
the inputs at runtime match the Generalized Symbolic Shapes that are inputs
to the TE Kernel. The computate to calculate all symbolic dimensions is
inlined in to the if block with the TE Kernel. All Sym Dim Value* are
appended to the end of the TE Kernel Graph/Node inputs, and the Node is
augmented with a integer list attr `symbolic_shape_inputs` that gives the
mapping from Value * -> Symbolic Shape int64_t value. For more lengthy IR
examples and walkthrough look at ShapeAnalysisTest.DynamicShapesFusion in
`test_shape_analysis` Returns True on Success, False on Failure, can fail if
shape propagation fails to propagate # of dims or if complete shapes on
inputs not set.
Example transformation
```
graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
%y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
%z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
%3 : Tensor = prim::TensorExprGroup_0(%x_inp, %y_inp, %z_inp)
return ()
with prim::TensorExprGroup_0 = graph(%x.1 : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
%y.1 : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
%z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
%3 : int = prim::Constant[value=0]()
%4 : Tensor = aten::tanh(%x.1)
%5 : Tensor = aten::erf(%4)
%6 : Tensor = aten::relu(%y.1)
%7 : Tensor[] = prim::ListConstruct(%5, %6)
%8 : Tensor = aten::cat(%7, %3)
%9 : Tensor = aten::hardswish(%8)
%10 : Tensor = aten::mul(%9, %z)
return (%9)
```
->
```
graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
%y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
%z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
%4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp)
%5 : Tensor = prim::If(%4)
block0():
%15 : int[] = aten::size(%x_inp)
%16 : int[] = aten::size(%y_inp)
%17 : int = prim::Constant[value=1]()
%18 : int = prim::Constant[value=0]()
%elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10
%elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10
%elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10
%cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29
%3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
-> (%3)
block1():
%14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
-> (%14)
return ()
with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
%y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
%z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
%SS_5 : int,
%SS_4 : int,
%SS_3 : int,
%SS_2 : int):
%3 : int = prim::Constant[value=0]()
%4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1)
%5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4)
%6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1)
%7 : Tensor[] = prim::ListConstruct(%5, %6)
%8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3)
%9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8)
%10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z)
return (%9)
```
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D31732414
Pulled By: eellison
fbshipit-source-id: 290a94a667c20467717202a43c60e4f9ca4c00e2