[ONNX] Make Non-Float Op Exportation Compatible to Avoid Invalid ONNX Models
There are a few ONNX operators do not support non-float (e.g., integer) inputs at early versions. For example, Clip supports non-float types until [opset 12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#type-constraints-280), that said older versions like [opset 6](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#type-constraints-107) cannot deal with integer types.
I initially find such a bug in Clip (https://github.com/pytorch/pytorch/pull/70584), but later found more:
1. Clip < 12;
2. Min/Max < 12;
3. ReLU < 14;
4. Pad < 11;
In PyTorch, if we export Max-11 with integer inputs, actually the exportation will succeed; however, fail when imported by other frameworks like ONNXRuntime.
```python
import torch
class Net(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor):
return torch.max(x, x + 1)
net = Net()
onnx_model = 'test.onnx'
torch.onnx.export(net, (torch.zeros((3, 3), dtype=torch.int32),),
onnx_model, verbose=True, opset_version=11)
```
This is an unexpected behavior as we want to ensure that every model exported by PyTorch is valid (https://github.com/pytorch/pytorch/pull/70584#issuecomment-1020636579). Theoretically, we can simply forbid such cases (e.g., `Clip<int>` < 12, `ReLU<int>` < 14). But actually we can enhance the compatibility and flexibility of PyTorch by simply casting inputs of those operators into float tensors, which allows the float operator functions, and then casting it back to original types.
This PR implements the second approach to achieve better compatibility in PyTorch.
@garymm @thiagocrepaldi
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72401
Approved by: https://github.com/garymm, https://github.com/thiagocrepaldi