keep output type after calling SubgraphRewriter (#65453)
Summary:
For jit **SubgraphRewriter**, it doesn't keep output type after overwriting the old graph, for example, in profiling mode, the old graph has the old operator's shapes, but after replacing the old operator with a newer operator by applying **SubgraphRewriter**, the tensor shape info was eliminated.
The activation is that I want to replace pytorch convolution with a customer's convolution, I first register **aten::_convolution** as a profiler node that can reorder the input and output's shapes, and then using graph rewrite to replace it as **aten::conv2d**, which tensors' shapes info are eliminated. I hope using input size do some pre-progress before replacing **aten::conv2d** with the customer's convolution.
Before rewrite:
```
graph(%self.1 : __torch__.MyModule,
%x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
%7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/ site-packages/torch/nn/modules/conv.py:443:0
%6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6 /site-packages/torch/nn/modules/conv.py:443:0
%5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6 /site-packages/torch/nn/modules/conv.py:443:0
%4 : NoneType = prim::Constant()
%3 : int[] = prim::Constant[value=[1, 1]]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
%z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:2 2:0
%weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
%x : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::_convolution(%x.1, %weight, %4, %3, %2, %3, %6, %2, %7, %6, %6, %5, %5), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3. 6/site-packages/torch/nn/modules/conv.py:443:0
%16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%x, %z, %7) # jit_test.py: 24:0
return (%16)
```
after rewrite by using **aten::conv2d**
```
graph(%self.1 : __torch__.MyModule,
%x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
%7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
%6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
%5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
%4 : NoneType = prim::Constant()
%3 : int[] = prim::Constant[value=[1, 1]]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
%z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
%weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
%18 : Tensor = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
%16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
return (%16)
```
expected result after replace **aten::_convolution** with **aten::conv2d**:
```
graph(%self.1 : __torch__.MyModule,
%x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
%7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/ site-packages/torch/nn/modules/conv.py:443:0
%6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6 /site-packages/torch/nn/modules/conv.py:443:0
%5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6 /site-packages/torch/nn/modules/conv.py:443:0
%4 : NoneType = prim::Constant()
%3 : int[] = prim::Constant[value=[1, 1]]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
%z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:2 2:0
%weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
%18 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
%16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py :24:0
return (%16)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65453
Reviewed By: zdevito
Differential Revision: D31162489
Pulled By: ZolotukhinM
fbshipit-source-id: 0d1c1d607cb612df47c64f173d9f4c9e8b1d6c49