pytorch
99848c72 - [quant] Add tensor_qparam variant to fake_quantize_per_tensor (#61317)

Commit
3 years ago
[quant] Add tensor_qparam variant to fake_quantize_per_tensor (#61317) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61317 Add an overload to fake_quantize_per_tensor that accepts scale/zero_point as input. The reasons to do this are * required for fused observer + fake_quant operator on GPU where the scale/zero_point will be calculated by the observer on device. Passing tensor inputs enables us to directly access the scale/zero-point value in the cuda kernel to avoid extra copies/malloc * enables us to pass in float as scale dtype and int32 as zero_point dtype (which is consistent with what the quantize call actually uses) https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/affine_quantizer_base.cpp#L52-L53 * overload consistent with `quantizer_per_tensor.tensor_qparams` ghstack-source-id: 133370216 Test Plan: buck test mode/dev-nosan caffe2/test/:quantization -- test_backward_per_tensor_cachemask buck test mode/dev-nosan caffe2/test/:quantization -- test_forward_per_tensor_cachemask Reviewed By: raghuramank100 Differential Revision: D29552727 fbshipit-source-id: cbb9af40fc575ad27a29c646b760d5ee52cc923d
Author
Parents
Loading