[fx2trt] fuse permute + matmul using a pass instead of hardcoding it as a leaf module (#65482)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65482
Currently we hardcode permute + bmm in a module and tag it as a leaf module during tracing. This diff introduces a pass that fuses permute + matmul into a single node instead.
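A rough illustration of what such a fusion pass does, written against a toy node list rather than the real fx graph. Everything here (the `Node` class, the `permute_matmul` op name, the `lhs_permutation`/`rhs_permutation` kwargs) is hypothetical and chosen only for the sketch; it is not the code in this diff.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for a graph node (hypothetical, not torch.fx.Node)."""
    name: str
    op: str                      # "input", "permute", "matmul", "output", ...
    args: tuple = ()             # names of the nodes this node consumes
    kwargs: dict = field(default_factory=dict)


def fuse_permute_matmul(nodes):
    """Return a new node list where permute -> matmul is a single fused node."""
    by_name = {n.name: n for n in nodes}
    rewritten = []
    for n in nodes:
        if n.op != "matmul":
            rewritten.append(n)
            continue
        new_args, perms = [], {}
        for side, arg in zip(("lhs", "rhs"), n.args):
            producer = by_name[arg]
            if producer.op == "permute":
                # Read the permute's input directly and record its dims
                # on the fused node instead.
                new_args.append(producer.args[0])
                perms[side + "_permutation"] = producer.kwargs["dims"]
            else:
                new_args.append(arg)
        if perms:
            # Reuse the matmul's name so downstream users stay valid.
            rewritten.append(Node(n.name, "permute_matmul", tuple(new_args), perms))
        else:
            rewritten.append(n)
    # Single dead-code sweep: drop permutes that no longer have any user.
    used = {a for n in rewritten for a in n.args}
    return [n for n in rewritten if n.op != "permute" or n.name in used]


graph = [
    Node("a", "input"),
    Node("b", "input"),
    Node("b_t", "permute", ("b",), {"dims": (0, 2, 1)}),
    Node("mm", "matmul", ("a", "b_t")),
    Node("ret", "output", ("mm",)),
]
fused_graph = fuse_permute_matmul(graph)
```

In this toy version a permute that still has other users is left in place, so only the matmul's own view of it is folded into the fused node.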
TODO:
Fusion transformations of this kind share a lot of code, such as finding the fusion pattern and replacing the original nodes with the fused node. The current fx subgraph rewriter lets us specify the patterns we want to replace, but we would need to extend it to allow specifying constraints on nodes' kwargs.
Reviewed By: yinghai
Differential Revision: D31022055
fbshipit-source-id: 13d1f18d79b09d371897ecde840f582ccaf5713a