Simplify cat fusion (#15633)
Summary:
This makes the definition of a "fusable node" much simpler,
as we don't need to keep considering whether something has to be an
"exit node" at every step. The fuser now tries to maximize the
pointwise fusions first, and proceeds to prepending chunks and appending
concats only once a fixed point is reached.
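
To illustrate the ordering described above, here is a minimal toy sketch in Python. It is not the actual fuser code: the graph encoding, the node kinds, and the names `fuse`, `_common_group`, and `example` are hypothetical, and real concerns such as cycle checks and device or shape constraints are ignored. It only shows "grow pointwise groups until nothing changes, then attach chunks and cats".

```python
# Toy model only, NOT the PyTorch graph fuser. All names here are hypothetical.
POINTWISE = {"add", "mul", "relu"}


def _common_group(group_of, names):
    """Return the single group shared by all `names`, or None if there is none."""
    groups = [group_of.get(n) for n in names]
    if groups and groups[0] is not None and all(g is groups[0] for g in groups):
        return groups[0]
    return None


def fuse(nodes):
    """nodes maps a node name to (kind, [input names]); returns sorted fusion groups."""
    # Phase 1: merge pointwise producers with pointwise consumers until a
    # fixed point is reached (a full pass that changes nothing).
    group_of = {n: {n} for n, (kind, _) in nodes.items() if kind in POINTWISE}
    changed = True
    while changed:
        changed = False
        for name, (kind, inputs) in nodes.items():
            if kind not in POINTWISE:
                continue
            for inp in inputs:
                if inp in group_of and group_of[inp] is not group_of[name]:
                    merged = group_of[inp] | group_of[name]
                    for member in merged:
                        group_of[member] = merged
                    changed = True

    # Phase 2: only now prepend chunks and append cats to existing groups.
    for name, (kind, inputs) in nodes.items():
        if kind == "chunk":
            users = [u for u, (_, ins) in nodes.items() if name in ins]
            group = _common_group(group_of, users)
        elif kind == "cat":
            group = _common_group(group_of, inputs)
        else:
            group = None
        if group is not None:
            group.add(name)
            group_of[name] = group

    unique = {id(g): g for g in group_of.values()}
    return sorted(sorted(g) for g in unique.values())


# A small made-up graph: chunk -> pointwise ops -> cat -> matmul.
example = {
    "x": ("input", []),
    "c": ("chunk", ["x"]),
    "a": ("add", ["c"]),
    "r": ("relu", ["a"]),
    "m": ("mul", ["a", "r"]),
    "out": ("cat", ["r", "m"]),
    "y": ("matmul", ["out"]),
}
groups = fuse(example)
```

Because phase 1 never has to ask whether a node must be an "exit node", the pointwise merging rule stays a single check per edge. In this toy graph the chunk and the cat both join the one pointwise group, while the non-pointwise `matmul` stays outside it.
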
This patch not only makes the fuser much simpler to reason about,
but also makes it significantly easier to implement features like
SumToSize fusion, which improve the performance of derivative graphs.
cc zou3519 mruberry
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15633
Differential Revision: D13575306
Pulled By: zou3519
fbshipit-source-id: 0c55ea61d65d1f1ed3d75a8e1e83bc85a83f3aff