[FX] fuse permute021 linear pass for trt lowering (#66362)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66362
In general, we cannot rely on Permute021Linear being kept as is until the lowering phase, because our transformations could have traced through this module before then. An acc-based FX pass is a more reliable way to recover the perf.
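For illustration only, the idea of the pass can be sketched on a toy graph: find a `permute` with dims `(0, 2, 1)` whose only user is a `linear`, and replace the pair with one fused node. This is not the code in this PR. The `Node` class, the op strings, and `fuse_permute021_linear` are hypothetical stand-ins for the real FX graph and acc ops.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Node:
    # Hypothetical toy IR node, not torch.fx.Node.
    name: str
    op: str                        # e.g. "placeholder", "permute", "linear"
    inputs: Tuple[str, ...] = ()   # names of producer nodes
    dims: Tuple[int, ...] = ()     # only meaningful for "permute"


def fuse_permute021_linear(nodes: List[Node]) -> List[Node]:
    """Rewrite permute(0, 2, 1) -> linear into one "permute021_linear" node."""
    by_name = {n.name: n for n in nodes}
    # Count users so a permute that feeds other nodes is left alone.
    users = {n.name: 0 for n in nodes}
    for n in nodes:
        for src in n.inputs:
            users[src] += 1

    rewritten: List[Node] = []
    dead = set()
    for n in nodes:
        if n.op == "linear":
            src = by_name[n.inputs[0]]
            if src.op == "permute" and src.dims == (0, 2, 1) and users[src.name] == 1:
                # The fused node reads the permute's input directly.
                fused_inputs = (src.inputs[0],) + n.inputs[1:]
                rewritten.append(Node(n.name, "permute021_linear", fused_inputs))
                dead.add(src.name)
                continue
        rewritten.append(n)
    # Drop permute nodes that were absorbed into a fused node.
    return [n for n in rewritten if n.name not in dead]


graph = [
    Node("x", "placeholder"),
    Node("w", "get_attr"),
    Node("b", "get_attr"),
    Node("p", "permute", ("x",), (0, 2, 1)),
    Node("y", "linear", ("p", "w", "b")),
    Node("out", "output", ("y",)),
]
fused_graph = fuse_permute021_linear(graph)
print([(n.name, n.op, n.inputs) for n in fused_graph])
```

Because the match runs on the already-traced graph, it does not depend on the original module surviving tracing; the user-count check keeps the rewrite safe when the permuted tensor is also consumed elsewhere.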
Test Plan:
```
buck run mode/opt -c python.package_style=inplace -c fbcode.nvcc_arch=a100 //hpc/new/models/ads/benchmarks:ads_dense_benchmark -- over-arch --model-version=23x_3tb --batch-size=2048
OverArch, PyTorch, FP16, BS: 2048, TFLOP/s: 53.22, Time per iter: 14.46ms, QPS: 141629.45
OverArch, TensorRT, FP16, BS: 2048, TFLOP/s: 92.20, Time per iter: 8.35ms, QPS: 245354.15
```
Unit test:
```
buck test mode/dev-nosan caffe2/torch/fb/fx2trt:test_fuse_permute_linear_trt
```
Reviewed By: jianyuh, wushirong, 842974287
Differential Revision: D31525307
fbshipit-source-id: b472a8c277aa4d156d933d6a5abec091133f22c5