pytorch
9b4dc56c - [ONNX] Fix quantization outputs' dtype (#79690)

Commit · 3 years ago
[ONNX] Fix quantization outputs' dtype (#79690)

Part of #79263

Previously, all quantized PyTorch tensors were cast to the dtypes that comply with ONNX's definition, i.e. `scale` was cast to `double` and `zero_point` to `int64`. These casts led to inconsistent dtypes when comparing PyTorch's outputs with ONNX Runtime's outputs. Now a `cast_onnx_accepted` argument is added to the `unpack_quantized_tensor` function: when making example inputs for ONNX, the tensors are cast to the ONNX-compliant dtypes; otherwise they keep PyTorch's default dtypes for quantization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79690
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
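The flag described above can be sketched roughly as follows. This is a hypothetical, simplified stand-in for the real `unpack_quantized_tensor` (it uses NumPy arrays instead of quantized torch tensors, and the exact signature is an assumption); it only illustrates the conditional cast of `scale` to `double` and `zero_point` to `int64` that ONNX requires:

```python
import numpy as np

def unpack_quantized_tensor(int_repr, scale, zero_point, cast_onnx_accepted=True):
    """Sketch of the dtype handling described in the commit (not the real API).

    int_repr: the integer representation of the quantized tensor.
    scale / zero_point: quantization parameters.
    cast_onnx_accepted: when True, cast the parameters to the dtypes the
    ONNX spec accepts (double scale, int64 zero point); when False, leave
    them in whatever dtypes the framework produced them in.
    """
    scale = np.asarray(scale)
    zero_point = np.asarray(zero_point)
    if cast_onnx_accepted:
        # ONNX's QuantizeLinear/DequantizeLinear definitions expect
        # a double scale and an int64 zero point.
        scale = scale.astype(np.float64)
        zero_point = zero_point.astype(np.int64)
    return int_repr, scale, zero_point
```

With `cast_onnx_accepted=True` the unpacked parameters are suitable as example inputs for ONNX export; with `False` they pass through unchanged, so comparisons against the framework's own outputs see consistent dtypes.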