pytorch
fdd28399 - Replace unsqueeze transform with stack (#101766)

Commit
1 year ago
Replace unsqueeze transform with stack (#101766) As part of split-cat transforms, we needed to unsqueeze additional inputs (not coming from split) but going to the cat/stack nodes. However, this leads to patterns like: ``` split -> unsqueeze -> cat ``` when there are multiple splits going into cat. An alternative is to use stack rather than unsqueeze, leading to patterns like: ``` split -> stack -> cat ``` This is much better, as repeated applications of the same pattern will further simplify "split->stack", which is not trivial in case of "split->unsqueeze->cat". Another nice side-effect is lesser number of nodes in the graph overall. Differential Revision: [D45952452](https://our.internmc.facebook.com/intern/diff/D45952452/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/101766 Approved by: https://github.com/jansel
Committer
Parents
Loading