keep output type after calling SubgraphRewriter (#65453)

Summary: The JIT **SubgraphRewriter** does not keep output types when it overwrites the old graph. For example, in profiling mode the old graph carries the old operator's tensor shapes, but after **SubgraphRewriter** replaces the old operator with a new one, the tensor shape info is eliminated.

The motivation: I want to replace the PyTorch convolution with a custom convolution. I first register **aten::_convolution** as a profiled node so that its input and output shapes are recorded, then use a graph rewrite to replace it with **aten::conv2d**, at which point the tensors' shape info is lost. I want to use the input sizes for some pre-processing before replacing **aten::conv2d** with the custom convolution. A minimal sketch of registering such a rewrite is shown below.
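For reference, this kind of rewrite can be driven through the C++ **SubgraphRewriter** API. The sketch below is illustrative rather than code from this commit: the function name is made up, and the pattern strings are modeled on the 13-argument **aten::_convolution** call visible in the dumps that follow.

```
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

using namespace torch::jit;

// Sketch: replace every aten::_convolution in the graph with aten::conv2d.
// The extra pattern inputs (transposed, benchmark, ...) are matched but
// simply left unused by the replacement graph.
void replaceConvolutionWithConv2d(std::shared_ptr<Graph>& graph) {
  const std::string convolution_pattern = R"(
    graph(%input, %weight, %bias, %stride, %padding, %dilation,
          %transposed, %output_padding, %groups, %benchmark,
          %deterministic, %cudnn_enabled, %allow_tf32):
      %r = aten::_convolution(%input, %weight, %bias, %stride, %padding,
                              %dilation, %transposed, %output_padding,
                              %groups, %benchmark, %deterministic,
                              %cudnn_enabled, %allow_tf32)
      return (%r))";

  const std::string conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride, %padding, %dilation,
          %transposed, %output_padding, %groups, %benchmark,
          %deterministic, %cudnn_enabled, %allow_tf32):
      %r = aten::conv2d(%input, %weight, %bias, %stride, %padding,
                        %dilation, %groups)
      return (%r))";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(convolution_pattern, conv2d_pattern);
  rewriter.runOnGraph(graph);
}
```

The replacement graph deliberately declares the same inputs as the pattern; the ones **aten::conv2d** does not need are simply never referenced.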
Before rewrite:

```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %x : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::_convolution(%x.1, %weight, %4, %3, %2, %3, %6, %2, %7, %6, %6, %5, %5), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%x, %z, %7) # jit_test.py:24:0
  return (%16)
```

After the rewrite using **aten::conv2d** (note that %18 has lost its profiled type and is now a plain Tensor):

```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Tensor = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)
```

Expected result after replacing **aten::_convolution** with **aten::conv2d**:

```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)
```
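The essence of the fix is to carry each matched output's type over to the corresponding output of the inserted replacement graph before uses are rewired. Below is a rough sketch of that idea, with illustrative names rather than the exact code in subgraph_rewrite.cpp.

```
#include <torch/csrc/jit/ir/ir.h>

using namespace torch::jit;

// Sketch of the fix's core idea: when splicing in the replacement
// subgraph, copy the (possibly profiled) type of each old output onto
// the corresponding new output, then redirect all uses. Without the
// setType call, %18 above stays a plain Tensor; with it, the
// Float(2, 3, 20, 20, ...) profile survives the rewrite.
void transferTypesAndRewire(at::ArrayRef<Value*> old_outputs,
                            at::ArrayRef<Value*> new_outputs) {
  TORCH_INTERNAL_ASSERT(old_outputs.size() == new_outputs.size());
  for (size_t i = 0; i < old_outputs.size(); ++i) {
    new_outputs[i]->setType(old_outputs[i]->type());
    old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
  }
}
```

With the type transfer in place, the rewritten graph matches the expected dump above instead of the actual one.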
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65453
Reviewed By: zdevito
Differential Revision: D31162489
Pulled By: ZolotukhinM
fbshipit-source-id: 0d1c1d607cb612df47c64f173d9f4c9e8b1d6c49