pytorch
b3e4dab4 - [quant] Input-Weight Equalization - Conv convert support (#61287)

Commit
4 years ago
[quant] Input-Weight Equalization - Conv convert support (#61287) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61287 Modifications to functions during convert() to support equalization. Note that this implementation does not work for connected F.conv2d layers yet. Initial: ``` w | x -> conv -> y ``` After prepare: ``` w | weight_quant_obs | weight_eq_obs | x -> input_quant_obs -> input_eq_obs -> conv -> out_quant_obs -> y ``` After convert: ``` scale, zero_point w (scaled) | | x -> mul -> quantize_per_tensor (scaled) -> quantized::conv -> dequant -> y | eq_scale ``` Test Plan: `python test/test_quantization.py TestEqualizeFx` Initial model: ``` ConvModel( (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False) ) ``` After prepare: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%x,), kwargs = {}) %x_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=x_activation_post_process_0_equalization_process_0](args = (%x_activation_post_process_0,), kwargs = {}) %conv : [#users=1] = call_module[target=conv](args = (%x_activation_post_process_0_equalization_process_0,), kwargs = {}) %conv_activation_post_process_0 : [#users=1] = call_module[target=conv_activation_post_process_0](args = (%conv,), kwargs = {}) return conv_activation_post_process_0 ``` After equalization functions: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {}) %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%mul,), kwargs = {}) %conv : [#users=1] = call_module[target=conv](args = (%x_activation_post_process_0,), kwargs = {}) %conv_activation_post_process_0 : [#users=1] = call_module[target=conv_activation_post_process_0](args = (%conv,), kwargs = {}) return conv_activation_post_process_0 ``` After convert: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {}) %conv_input_scale_0 : [#users=1] = get_attr[target=conv_input_scale_0] %conv_input_zero_point_0 : [#users=1] = get_attr[target=conv_input_zero_point_0] %quantize_per_tensor : [#users=1] = call_function[target=torch.quantize_per_tensor](args = (%mul, %conv_input_scale_0, %conv_input_zero_point_0, torch.quint8), kwargs = {}) %conv : [#users=1] = call_module[target=conv](args = (%quantize_per_tensor,), kwargs = {}) %dequantize : [#users=1] = call_method[target=dequantize](args = (%conv,), kwargs = {}) return dequantize ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D29557055 fbshipit-source-id: dc9f44182e31fa362c43ad2dfe224e6f4e4a730e
Author
Parents
Loading