[ONNX] Apply Common Subexpression Elimination pass to ONNX export (#85665)
## Summary
Exporting graphs under Autocast may fail due to a limitation of the JIT tracer. Disabling the Autocast cache makes tracing work, but it can cause a performance hit when weights are reused, for example across convolutions.
Applying CSE recovers that lost performance.
PS: As a comment on #84092 mentioned, disabling the Autocast cache is an acceptable workaround and is used throughout the PyTorch code base.
Fixes #84092
## Example graphs before and after applying CSE:
### Example: eliminating `%17` and reusing `%16` instead
```python
# BEFORE
graph(%0 : Float(requires_grad=0, device=cpu)):
%3 : Scalar = aten::ScalarImplicit(%0), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule::
%13 : int = prim::Constant[value=3](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%14 : int = prim::Constant[value=4](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%15 : int[] = prim::ListConstruct(%13, %14), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule::
%16 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule::
%17 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule::
%18 : Device = prim::Constant[value="cpu"](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%19 : bool = prim::Constant[value=0](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%20 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::full(%15, %3, %16, %17, %18, %19), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
return (%20)
# AFTER
graph(%0 : Float(requires_grad=0, device=cpu)):
%3 : Scalar = aten::ScalarImplicit(%0), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule::
%13 : int = prim::Constant[value=3](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%14 : int = prim::Constant[value=4](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%15 : int[] = prim::ListConstruct(%13, %14), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule::
%16 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule::
%18 : Device = prim::Constant[value="cpu"](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%19 : bool = prim::Constant[value=0](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
%20 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::full(%15, %3, %16, %16, %18, %19), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0
return (%20)
```
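The transformation above can be sketched in a few lines. This is a toy illustration of the CSE idea, not PyTorch's actual pass (which operates on the TorchScript IR in C++): two nodes with the same op and the same inputs are merged, and later uses are redirected to the surviving node, just as `%17` collapses into `%16` above.

```python
# Toy common-subexpression elimination over a tiny expression IR.
# NOT PyTorch's implementation -- just the idea behind the pass:
# identical (op, inputs) pairs are deduplicated.

def cse(nodes):
    """nodes: list of (name, op, args) in topological order.
    Returns (new_nodes, replacements), where replacements maps each
    removed node's name to the name of the surviving duplicate."""
    seen = {}          # (op, resolved args) -> surviving node name
    replacements = {}  # removed name -> surviving name
    new_nodes = []
    for name, op, args in nodes:
        # Rewrite inputs through earlier replacements first.
        args = tuple(replacements.get(a, a) for a in args)
        key = (op, args)
        if key in seen:
            replacements[name] = seen[key]  # duplicate: drop this node
        else:
            seen[key] = name
            new_nodes.append((name, op, args))
    return new_nodes, replacements

# Mirrors the BEFORE graph: %16 and %17 are both `prim::Constant()`
# with no inputs, so %17 is eliminated and aten::full's argument
# list is rewritten to use %16 twice, as in the AFTER graph.
graph = [
    ("%16", "prim::Constant", ()),
    ("%17", "prim::Constant", ()),
    ("%20", "aten::full", ("%15", "%3", "%16", "%17", "%18", "%19")),
]
optimized, repl = cse(graph)
```

Running this gives `repl == {"%17": "%16"}` and rewrites `aten::full` to consume `%16` for both of its `NoneType` arguments, matching the AFTER dump.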
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85665
Approved by: https://github.com/ngimel, https://github.com/AllenTiTaiWang, https://github.com/BowenBao