[JIT] Regularize tensorexpr fuser strategy with other fusers (#44972)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44972
Previously, our fusion strategy would be:
- start at the end of the block, find a fusable node
- iteratively try to merge inputs into the fusion group, sorted topologically
This strategy works pretty well, but has the possibility of missing fusion groups. See my attached test case for an example where we wouldn't find all possible fusion groups. bertmaher found an example of a missed fusion groups in one of our rnn examples (jit_premul) that caused a regression from the legacy fuser.
Here, I'm updating our fusion strategy to be the same as our other fusion passes - create_autodiff_subgraphs, and graph_fuser.cpp.
The basic strategy is:
- iterate until you find a fusible node
- try to merge the nodes inputs, whenever a succesful merge occurs restart at the beginning of the nodes inputs
- after you've exhausted a node, continue searching the block for fusion opportunities from the node
- continue doing this on the block until we go through an iteration without an succesful merges
Since we create the fusion groups once, and only re-specialize within the fusion groups, we should be running this very infrequently (only re-triggers when we fail undefinedness specializations). Also bc it's the same algorithm as the existing fuser it is unlikely to cause a regression.
Test Plan: Imported from OSS
Reviewed By: Krovatkin, robieta
Differential Revision: D23821581
Pulled By: eellison
fbshipit-source-id: e513d1ef719120dadb0bfafc7a14f4254cd806ee