pytorch
5dd288eb - [JIT] Regularize tensorexpr fuser strategy with other fusers (#44972)

Commit
4 years ago
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44972

Previously, our fusion strategy was:
- start at the end of the block and find a fusible node
- iteratively try to merge inputs into the fusion group, sorted topologically

This strategy works pretty well, but it can miss fusion groups. See the attached test case for an example where we wouldn't find all possible fusion groups. bertmaher found an example of a missed fusion group in one of our RNN examples (jit_premul) that caused a regression relative to the legacy fuser.

Here, I'm updating our fusion strategy to match our other fusion passes, create_autodiff_subgraphs and graph_fuser.cpp. The basic strategy is:
- iterate until you find a fusible node
- try to merge the node's inputs; whenever a successful merge occurs, restart at the beginning of the node's inputs
- after you've exhausted a node, continue searching the block for fusion opportunities from that node
- continue doing this over the block until we go through an iteration without a successful merge

Since we create the fusion groups once, and only re-specialize within the fusion groups, we should be running this very infrequently (it only re-triggers when we fail undefinedness specializations). Also, because it's the same algorithm as the existing fuser, it is unlikely to cause a regression.

Test Plan: Imported from OSS

Reviewed By: Krovatkin, robieta

Differential Revision: D23821581

Pulled By: eellison

fbshipit-source-id: e513d1ef719120dadb0bfafc7a14f4254cd806ee
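The fixed-point strategy described above can be sketched as follows. This is a hypothetical toy model, not the actual pass (which operates on TorchScript IR in C++); the graph representation, the `fusible` predicate, and the `can_merge` check are all illustrative assumptions.

```python
def fuse_block(nodes, fusible, can_merge):
    """Toy sketch of the fusion strategy: scan a topologically ordered
    block for fusible nodes, absorb each node's input producers into its
    fusion group (restarting over the inputs after every successful
    merge), and repeat whole-block passes until a pass makes no merges.

    Each node is a dict: {"ops": [...], "inputs": [other nodes]}.
    `fusible(node)` and `can_merge(group, producer)` are caller-supplied
    predicates standing in for the real fusibility checks.
    """
    any_merged = True
    while any_merged:                      # repeat until a fixed point
        any_merged = False
        for node in list(nodes):           # scan the block for fusible nodes
            if node not in nodes or not fusible(node):
                continue
            merged = True
            while merged:                  # exhaust this node's inputs,
                merged = False             # restarting after each merge
                for producer in list(node["inputs"]):
                    if producer in nodes and can_merge(node, producer):
                        # absorb the producer into the fusion group
                        node["ops"] = producer["ops"] + node["ops"]
                        node["inputs"].remove(producer)
                        for p in producer["inputs"]:
                            if p not in node["inputs"]:
                                node["inputs"].append(p)
                        nodes.remove(producer)
                        merged = any_merged = True
                        break
    return nodes
```

For example, a three-node chain `a -> b -> c` where everything is fusible collapses into a single group containing all three ops. The outer `while any_merged` loop is what distinguishes this from the previous end-of-block strategy: a merge anywhere in the block triggers another full pass, so opportunities exposed by earlier merges are not missed.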
Author
Elias Ellison