pytorch
0f05e398 - [ONNX] Fix shape inconsistency when exporting scalar log2 (#78701)

Commit
2 years ago
[ONNX] Fix shape inconsistency when exporting scalar log2 (#78701) This is a simple fix addressing the exportation when the input to `torch.log2` is scalar. `log2(x)` will be exported as `log(x) / log(2)`, which creates a `log` node followed by a `div` node that divides it by a constant. The constant is constructed not as a scalar but as a tensor of shape `[1]`, so a scalar input here will get broadcasted creating the output tensor with shape `[1]`, while originally the torch model's output is a scalar. ```python import torch import onnx import numpy as np class Model(torch.nn.Module): def forward(self, x): return torch.log2(x) x = torch.tensor(1.) # scalar model = Model() torch.onnx.export(model, (x, ), "output.onnx", opset_version=14, output_names=['o0'], input_names=['i0']) y_trh = model(x).numpy() model = onnx.load("output.onnx") print(model.graph.output[0]) import onnxruntime as ort sess = ort.InferenceSession( "output.onnx", providers=['CPUExecutionProvider']) y_ort = sess.run(['o0'], {'i0': x.numpy()})[0] assert y_ort.shape == y_trh.shape, 'shape mismatch, ORT is `{}` but PyTorch is `{}`'.format( y_ort.shape, y_trh.shape) ``` The resulting ONNX model has an output of shape `[1]` and causes shape mismatch between ORT and PyTorch. The output: ``` name: "o0" type { tensor_type { elem_type: 1 shape { dim { dim_value: 1 } } } } Traceback (most recent call last): File "test.py", line 501, in <module> y_ort.shape, y_trh.shape) AssertionError: shape mismatch, ORT is `(1,)` but PyTorch is `()` ``` After the fix, the output becomes: ``` name: "o0" type { tensor_type { elem_type: 1 shape { } } } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/78701 Approved by: https://github.com/justinchuby, https://github.com/BowenBao
Author
Committer
Parents
Loading