pytorch
3362c1d2 - [ONNX] add cast operator after reduce to match desired dtype (#100700)

Commit
1 year ago
[ONNX] add cast operator after reduce to match desired dtype (#100700) This PR conditionally inserts a cast operator after a reduction operation to match the specified dtype in the exported ONNX model. The code changes affect **opset9**, and **opset13**. I understand there's an [automatic upcast to int64](https://github.com/pytorch/pytorch/blob/c91a41fd6827f076224477913b1c09d20a887935/torch/onnx/symbolic_opset9.py#L783) before reduction most likely to prevent overflow so I left that alone and only conditionally add casting back to desired dtype. ## Test int32 ``` import torch import onnx a = torch.tensor([10, 20, 30, 80], dtype=torch.int32) def test(): class SumInt32(torch.nn.Module): def forward(self, a): return torch.sum(a, dtype=torch.int32) sumi = SumInt32().eval() assert sumi(a).dtype == torch.int32 print("Torch model output type matches input type") torch.onnx.export(sumi, (a), "/tmp/sumi_int32.onnx", opset_version=12) model = onnx.load("/tmp/sumi_int32.onnx") assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT32 print("ONNX model output type matches input type") test() ``` ![sumi_int32 onnx](https://user-images.githubusercontent.com/10516699/236499220-59b64821-5807-4f69-b0e2-90ae34280e03.png) ## Test int64 ``` import onnx import torch a = torch.tensor([10, 20, 30, 80], dtype=torch.int64) def test(): class SumInt64(torch.nn.Module): def forward(self, a): return torch.sum(a, dtype=torch.int64) sumi = SumInt64().eval() assert sumi(a).dtype == torch.int64 print("Torch model output type matches input type") torch.onnx.export(sumi, (a), "/tmp/sumi_int64.onnx", opset_version=12) model = onnx.load("/tmp/sumi_int64.onnx") assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT64 print("ONNX model output type matches input type") test() ``` ![sum_int64 onnx](https://user-images.githubusercontent.com/10516699/236422133-15f9cda3-242f-46da-9b23-c2e920f27078.png) Fixes https://github.com/pytorch/pytorch/issues/100097 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100700 Approved by: https://github.com/thiagocrepaldi
Author
Committer
Parents
Loading