[JIT] Remove profile nodes before BatchMM. (#43961)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43961
Currently we're removing prim::profile nodes and embed the type info
directly in the IR right before the fuser, because it is difficult to
fuse in a presence of prim::profile nodes. It turns out that BatchMM has
a similar problem: it doesn't work when there are prim::profile nodes in
the graph. These two passes run next to each other, so we could simply
remove prim::profile nodes slightly earlier: before the BatchMM pass.
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D23453266
Pulled By: ZolotukhinM
fbshipit-source-id: 92cb50863962109b3c0e0112e56c1f2cb7467ff1
Author
Mikhail Zolotukhin