[ONNX] add cast operator after reduce to match desired dtype (#100700)
This PR conditionally inserts a cast operator after a reduction operation to match the specified dtype in the exported ONNX model. The code changes affect **opset9**, and **opset13**.
I understand there's an [automatic upcast to int64](https://github.com/pytorch/pytorch/blob/c91a41fd6827f076224477913b1c09d20a887935/torch/onnx/symbolic_opset9.py#L783) before reduction most likely to prevent overflow so I left that alone and only conditionally add casting back to desired dtype.
## Test int32
```
import torch
import onnx
a = torch.tensor([10, 20, 30, 80], dtype=torch.int32)
def test():
class SumInt32(torch.nn.Module):
def forward(self, a):
return torch.sum(a, dtype=torch.int32)
sumi = SumInt32().eval()
assert sumi(a).dtype == torch.int32
print("Torch model output type matches input type")
torch.onnx.export(sumi, (a), "/tmp/sumi_int32.onnx", opset_version=12)
model = onnx.load("/tmp/sumi_int32.onnx")
assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT32
print("ONNX model output type matches input type")
test()
```
![sumi_int32 onnx](https://user-images.githubusercontent.com/10516699/236499220-59b64821-5807-4f69-b0e2-90ae34280e03.png)
## Test int64
```
import onnx
import torch
a = torch.tensor([10, 20, 30, 80], dtype=torch.int64)
def test():
class SumInt64(torch.nn.Module):
def forward(self, a):
return torch.sum(a, dtype=torch.int64)
sumi = SumInt64().eval()
assert sumi(a).dtype == torch.int64
print("Torch model output type matches input type")
torch.onnx.export(sumi, (a), "/tmp/sumi_int64.onnx", opset_version=12)
model = onnx.load("/tmp/sumi_int64.onnx")
assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT64
print("ONNX model output type matches input type")
test()
```
![sum_int64 onnx](https://user-images.githubusercontent.com/10516699/236422133-15f9cda3-242f-46da-9b23-c2e920f27078.png)
Fixes https://github.com/pytorch/pytorch/issues/100097
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100700
Approved by: https://github.com/thiagocrepaldi