pytorch
6cfe555f - [ONNX] Apply Common Subexpression Elimination pass to ONNX export (#85665)

Commit
2 years ago
[ONNX] Apply Common Subexpression Elimination pass to ONNX export (#85665) ## Summary Exporting graphs with Autocast may fail due to a limitation on JIT tracer. By disabling Autocast cache, tracer works, but there can be performance hit when there is reuse of weights in convolution, for example By applying CSE, such performance loss can be reverted. ps: As a comment at #84092 mentioned, disabling Autocast cache is an acceptable workaround and used throughout PyTorch code. Fixes #84092 ## Examples of before and after CSE being applied: ### Example: eliminating `%17` and reusing `%16` instead ```python # BEFORE graph(%0 : Float(requires_grad=0, device=cpu)): %3 : Scalar = aten::ScalarImplicit(%0), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %13 : int = prim::Constant[value=3](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %14 : int = prim::Constant[value=4](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %15 : int[] = prim::ListConstruct(%13, %14), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %16 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %17 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %18 : Device = prim::Constant[value="cpu"](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %19 : bool = prim::Constant[value=0](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %20 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::full(%15, %3, %16, %17, %18, %19), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 return (%20) AFTER graph(%0 : Float(requires_grad=0, device=cpu)): %3 : Scalar = aten::ScalarImplicit(%0), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %13 : int = prim::Constant[value=3](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %14 : int = prim::Constant[value=4](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %15 : int[] = prim::ListConstruct(%13, %14), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %16 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %18 : Device = prim::Constant[value="cpu"](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %19 : bool = prim::Constant[value=0](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %20 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::full(%15, %3, %16, %16, %18, %19), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 return (%20) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/85665 Approved by: https://github.com/ngimel, https://github.com/AllenTiTaiWang, https://github.com/BowenBao
Author
Thiago Crepaldi
Committer
Parents
Loading