pytorch
0196b984 - Add Handling of Cat in Shape Analysis (#65575)

Commit
3 years ago
Add Handling of Cat in Shape Analysis (#65575) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65575 This is needed for lowering an NNC model to mobile. It is also the last class of unhandled ops which NNC fuses, and we need integration this for computing output symbolic shapes. The graph of with two dynamic shape inputs produces: ``` graph(%x.1 : Tensor(SS(-2), 2, 3), %y.1 : Tensor(SS(-3), 2, 3)): %5 : int = prim::Constant[value=0]() %4 : Tensor[] = prim::ListConstruct(%x.1, %y.1) %6 : Tensor(SS(-4), 2, 3) = aten::cat(%4, %5) # /private/home/eellison/pytorch/test/jit/test_symbolic_shape_analysis.py:290:19 return (%6) ``` With a partial eval graph of ``` Done with partial evaluation graph(%129 : int[], %130 : int[], %dim.14 : int): %738 : int = prim::Constant[value=3]() %737 : int = prim::Constant[value=2]() %132 : int = prim::Constant[value=0]() %392 : int = aten::__getitem__(%129, %132) # <string>:339:44 %417 : int = aten::__getitem__(%130, %132) # <string>:339:44 %cat_dim_size.48 : int = aten::add(%392, %417) # <string>:339:29 %result_size.5 : int[] = prim::ListConstruct(%cat_dim_size.48, %737, %738) return (%result_size.5) ``` To handle cat, I essentially make the cat shape op variadic, replacing ``` torch.cat([x, y] ... def cat_shape_op(tensors: List[List[int]], dim: int): ... op(tensors) ``` with ``` def cat_shape_op(x: List[int], y: List[int], dim: int): tensors = [x, y] op(tensors) ``` This reuses the existing input Tensor properties partial evaluation path and avoids having to add special handling to optimize out `len(tensors)` calls in the IR. Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D31797471 Pulled By: eellison fbshipit-source-id: 62c794533d5fabfd3fad056d7e5fe3e8781b22c5
Author
Elias Ellison
Parents
Loading