[ONNX] Updating input node removal in ONNX function_substitution pass. (#42146)
Summary:
ONNX pass `torch._C._jit_pass_onnx_function_substitution(graph)` inlines the function with the compiled torch graph. But while it removes all connections with the compiled function node (e.g. see below - `%6 : Function = prim::Constant[name="f"]()`), it does not remove the function node itself. For example, if the input graph is:
```
graph(%0 : Long(requires_grad=0, device=cpu),
      %1 : Long(requires_grad=0, device=cpu)):
  %6 : Function = prim::Constant[name="f"]()
  %7 : Tensor = prim::CallFunction(%6, %0, %1)
  return (%7)
```
then the output graph is:
```
graph(%0 : Long(requires_grad=0, device=cpu),
      %1 : Long(requires_grad=0, device=cpu)):
  %6 : Function = prim::Constant[name="f"]()
  %8 : int = prim::Constant[value=1]()
  %z.1 : Tensor = aten::sub(%0, %1, %8) # test/onnx/test_utility_funs.py:790:20
  %10 : Tensor = aten::add(%0, %z.1, %8) # test/onnx/test_utility_funs.py:791:23
  return (%10)
```
Note that `%6 : Function = prim::Constant[name="f"]()` has not been removed, even though it is no longer used.
This PR updates the pass to remove the function node completely. The updated output graph looks as follows:
```
graph(%0 : Long(requires_grad=0, device=cpu),
      %1 : Long(requires_grad=0, device=cpu)):
  %8 : int = prim::Constant[value=1]()
  %z.1 : Tensor = aten::sub(%0, %1, %8) # test/onnx/test_utility_funs.py:790:20
  %10 : Tensor = aten::add(%0, %z.1, %8) # test/onnx/test_utility_funs.py:791:23
  return (%10)
```
A test case covering this scenario has also been added.
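For context, here is a minimal sketch of a scripted function that produces a `prim::CallFunction` node like the one in the graphs above. It mirrors the `aten::sub`/`aten::add` pattern shown, but the module and function names are illustrative, not the actual test from this PR:

```python
import torch

@torch.jit.script
def f(x, y):
    # Matches the aten::sub / aten::add pattern in the graphs above.
    z = x - y
    return x + z

class CallFnModule(torch.nn.Module):  # hypothetical module name
    def forward(self, x, y):
        return f(x, y)

scripted = torch.jit.script(CallFnModule())
# scripted.graph typically contains a prim::CallFunction invoking `f`;
# torch._C._jit_pass_onnx_function_substitution inlines that call
# (and, with this PR, also drops the now-dead Function constant)
# during ONNX export.
print(scripted.graph)
```

Exporting such a module with `torch.onnx.export` exercises the function-substitution pass on a graph of exactly this shape.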
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42146
Reviewed By: VitalyFedyunin
Differential Revision: D22845314
Pulled By: bzinodev
fbshipit-source-id: 81fb351f0a36f47204e5327b60b84d7a91d3bcd9