[ONNX] Fix shape inconsistency when exporting scalar log2 (#78701)
This is a simple fix addressing the exportation when the input to `torch.log2` is scalar. `log2(x)` will be exported as `log(x) / log(2)`, which creates a `log` node followed by a `div` node that divides it by a constant. The constant is constructed not as a scalar but as a tensor of shape `[1]`, so a scalar input here will get broadcasted creating the output tensor with shape `[1]`, while originally the torch model's output is a scalar.
```python
import torch
import onnx
import numpy as np
class Model(torch.nn.Module):
def forward(self, x):
return torch.log2(x)
x = torch.tensor(1.) # scalar
model = Model()
torch.onnx.export(model, (x, ), "output.onnx", opset_version=14,
output_names=['o0'], input_names=['i0'])
y_trh = model(x).numpy()
model = onnx.load("output.onnx")
print(model.graph.output[0])
import onnxruntime as ort
sess = ort.InferenceSession(
"output.onnx", providers=['CPUExecutionProvider'])
y_ort = sess.run(['o0'], {'i0': x.numpy()})[0]
assert y_ort.shape == y_trh.shape, 'shape mismatch, ORT is `{}` but PyTorch is `{}`'.format(
y_ort.shape, y_trh.shape)
```
The resulting ONNX model has an output of shape `[1]` and causes shape mismatch between ORT and PyTorch. The output:
```
name: "o0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
Traceback (most recent call last):
File "test.py", line 501, in <module>
y_ort.shape, y_trh.shape)
AssertionError: shape mismatch, ORT is `(1,)` but PyTorch is `()`
```
After the fix, the output becomes:
```
name: "o0"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78701
Approved by: https://github.com/justinchuby, https://github.com/BowenBao