pytorch
0751a41a - [quant] Input-Weight Equalization - ConvReLU support (#61350)

Commit
4 years ago
[quant] Input-Weight Equalization - ConvReLU support (#61350) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61350 Applied changes in convert to allow for ConvReLU2d layers Initial Model: `x -> conv1 -> relu` After fusion: `x -> convRelu2d` After prepare: `x -> input_quant_obs -> input_eq_obs1 -> convRelu2d -> output_quant_obs1` After equalization functions: `x -> mul -> input_quant_obs (scaled) -> convRelu2d -> output_quant_obs` After convert: `x -> mul -> quantize_per_tensor -> quantized::convRelu2d -> dequantize` Test Plan: `python test/test_quantization.py TestEqualizeFx` Initial Model: ``` ConvReluModel( (fc): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1)) (relu): ReLU() ) ``` After prepare: ``` GraphModule( (x_activation_post_process_0): MinMaxObserver(min_val=5.960464477539063e-08, max_val=0.9999999403953552) (x_activation_post_process_0_equalization_process_0): _InputEqualizationObserver( (input_obs): PerChannelMinMaxObserver(min_val=tensor([1.1921e-07, 3.3379e-06, 5.9605e-08]), max_val=tensor([1.0000, 1.0000, 1.0000])) ) (fc): ConvReLU2d( (0): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1)) (1): ReLU() ) (fc_activation_post_process_0): MinMaxObserver(min_val=0.0, max_val=1.2341605424880981) ) graph(): %x : [#users=1] = placeholder[target=x] %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%x,), kwargs = {}) %x_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=x_activation_post_process_0_equalization_process_0](args = (%x_activation_post_process_0,), kwargs = {}) %fc : [#users=1] = call_module[target=fc](args = (%x_activation_post_process_0_equalization_process_0,), kwargs = {}) %fc_activation_post_process_0 : [#users=1] = call_module[target=fc_activation_post_process_0](args = (%fc,), kwargs = {}) return fc_activation_post_process_0 ``` After equalization functions: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {}) %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%mul,), kwargs = {}) %fc : [#users=1] = call_module[target=fc](args = (%x_activation_post_process_0,), kwargs = {}) %fc_activation_post_process_0 : [#users=1] = call_module[target=fc_activation_post_process_0](args = (%fc,), kwargs = {}) return fc_activation_post_process_0 ``` After convert: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {}) %fc_input_scale_0 : [#users=1] = get_attr[target=fc_input_scale_0] %fc_input_zero_point_0 : [#users=1] = get_attr[target=fc_input_zero_point_0] %quantize_per_tensor : [#users=1] = call_function[target=torch.quantize_per_tensor](args = (%mul, %fc_input_scale_0, %fc_input_zero_point_0, torch.quint8), kwargs = {}) %fc : [#users=1] = call_module[target=fc](args = (%quantize_per_tensor,), kwargs = {}) %dequantize : [#users=1] = call_method[target=dequantize](args = (%fc,), kwargs = {}) return dequantize ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D29638275 fbshipit-source-id: 40d4666a4451e132612ea38fdfeaaec177a1defb
Author
Parents
Loading