Add pattern to merge/simplify split-cat (#100713)
Summary:
In simple cases, both split and cat node can be removed in a "split->cat" pattern. However, there are various cases where they can't simply be removed and we need to simplify split/ add transforms before cat. Some such cases are:
* Split-dim != cat-dim (but equal split)
* Final node: cat vs stack
* Final node has additional args
* Shuffling of args between split/cat
* Some final nodes are non-(cat/stack)
For more details, please refer to https://docs.google.com/presentation/d/1SxBuY_FZfljSlX6i8slRNgP2CsUCICP0o4qe8cNNX8U/edit#slide=id.g232e9a90f64_0_273 (slides 8-15)
Differential Revision: D45452404
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100713
Approved by: https://github.com/jansel