pytorch
c697eeba - [JIT] Combine concat nodes where possible (#67000)

Commit
4 years ago
[JIT] Combine concat nodes where possible (#67000) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67000 See the [related issue](https://github.com/pytorch/pytorch/issues/66654) for context. This new JIT optimization transforms patterns like this: ``` %inputs.1 : Tensor[] = prim::ListConstruct(%a, %b, %c) %concat.1 : Tensor = aten::cat(%inputs, %dim) %inputs.2 : Tensor[] = prim::ListConstruct(%x, %concat.1, %y) %concat.2 : Tensor = aten::cat(%inputs.2, %dim) ``` into this: ``` %inputs.2 : Tensor[] = prim::ListConstruct(%x, %a, %b, %c, %y) %concat.2 : Tensor = aten::cat(%inputs.2, %dim) ``` (it can do this for chains of `aten::cat` longer than 2 as well) A few conditions have to hold: 1. The `dim`s have to match. 2. `inputs.1` and `inputs.2` cannot be mutated Test Plan: `buck test caffe2/test/cpp/jit:jit -- ConcatOpt` Reviewed By: d1jang Differential Revision: D31819491 fbshipit-source-id: 9f1a501d52099eb1a630b5dd906df4c38c3817ba
Author
Mike Iovine
Parents
Loading