pytorch
3e1c8168 - Add pattern to merge/simplify split-cat (#100713)

Commit
2 years ago
Add pattern to merge/simplify split-cat (#100713) Summary: In simple cases, both split and cat node can be removed in a "split->cat" pattern. However, there are various cases where they can't simply be removed and we need to simplify split/ add transforms before cat. Some such cases are: * Split-dim != cat-dim (but equal split) * Final node: cat vs stack * Final node has additional args * Shuffling of args between split/cat * Some final nodes are non-(cat/stack) For more details, please refer to https://docs.google.com/presentation/d/1SxBuY_FZfljSlX6i8slRNgP2CsUCICP0o4qe8cNNX8U/edit#slide=id.g232e9a90f64_0_273 (slides 8-15) Differential Revision: D45452404 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100713 Approved by: https://github.com/jansel
Committer
Parents
Loading