[JIT] Combine concat nodes where possible (#67000)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67000
See the [related issue](https://github.com/pytorch/pytorch/issues/66654) for context.
This new JIT optimization transforms patterns like this:
```
%inputs.1 : Tensor[] = prim::ListConstruct(%a, %b, %c)
%concat.1 : Tensor = aten::cat(%inputs, %dim)
%inputs.2 : Tensor[] = prim::ListConstruct(%x, %concat.1, %y)
%concat.2 : Tensor = aten::cat(%inputs.2, %dim)
```
into this:
```
%inputs.2 : Tensor[] = prim::ListConstruct(%x, %a, %b, %c, %y)
%concat.2 : Tensor = aten::cat(%inputs.2, %dim)
```
(it can do this for chains of `aten::cat` longer than 2 as well)
A few conditions have to hold:
1. The `dim`s have to match.
2. `inputs.1` and `inputs.2` cannot be mutated
Test Plan: `buck test caffe2/test/cpp/jit:jit -- ConcatOpt`
Reviewed By: d1jang
Differential Revision: D31819491
fbshipit-source-id: 9f1a501d52099eb1a630b5dd906df4c38c3817ba