pytorch
cc7de1df - Add Handling of Cat in Shape Analysis (#65575)

Commit
3 years ago
Add Handling of Cat in Shape Analysis (#65575) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65575 This is needed for lowering an NNC model to mobile. It is also the last class of unhandled ops which NNC fuses, and we need integration this for computing output symbolic shapes. The graph of with two dynamic shape inputs produces: ``` graph(%x.1 : Tensor(SS(-2), 2, 3), %y.1 : Tensor(SS(-3), 2, 3)): %5 : int = prim::Constant[value=0]() %4 : Tensor[] = prim::ListConstruct(%x.1, %y.1) %6 : Tensor(SS(-4), 2, 3) = aten::cat(%4, %5) # /private/home/eellison/pytorch/test/jit/test_symbolic_shape_analysis.py:290:19 return (%6) ``` With a partial eval graph of ``` Done with partial evaluation graph(%129 : int[], %130 : int[], %dim.14 : int): %738 : int = prim::Constant[value=3]() %737 : int = prim::Constant[value=2]() %132 : int = prim::Constant[value=0]() %392 : int = aten::__getitem__(%129, %132) # <string>:339:44 %417 : int = aten::__getitem__(%130, %132) # <string>:339:44 %cat_dim_size.48 : int = aten::add(%392, %417) # <string>:339:29 %result_size.5 : int[] = prim::ListConstruct(%cat_dim_size.48, %737, %738) return (%result_size.5) ``` To handle cat, I essentially make the cat shape op variadic, replacing ``` torch.cat([x, y] ... def cat_shape_op(tensors: List[List[int]], dim: int): ... op(tensors) ``` with ``` def cat_shape_op(x: List[int], y: List[int], dim: int): tensors = [x, y] op(tensors) ``` This reuses the existing input Tensor properties partial evaluation path and avoids having to add special handling to optimize out `len(tensors)` calls in the IR. Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D31732416 Pulled By: eellison fbshipit-source-id: 6d93ddf62c34846ec238159f75229632515530b7
Author
Elias Ellison
Parents
Loading