pytorch
ddf2ce03 - [quant] Input-Weight Equalization - support for connected linear layers (#60034)

Commit
3 years ago
[quant] Input-Weight Equalization - support for connected linear layers (#60034) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60034 Added support for equalizing models with connected linear layers. To account for connected linear layers, we will additionally multiply the previous weight values (row-wise) by the next equalization scale, and remove the input equalization observer between the two linear layers. We also want to scale the bias by the next equalization scale. The math is shown here: https://fb.quip.com/fK8rA9aRM4ca . Original Model: `x -> linear1 -> linear2` After `prepare_fx`: `x -> InpEqObs -> InpQuantObs -> linear1 -> OutQuantObs -> InpEqObs -> linear2` After equalization: `x -> mul -> InpQuantObs -> linear1 -> OutQuantObs -> linear2` Test Plan: `python test/test_quantization.py TestEqualizeFx.test_input_weight_equalization_convert` Original Model: ``` Linear2Module( (linear1): Linear(in_features=2, out_features=2, bias=True) (linear2): Linear(in_features=2, out_features=2, bias=True) ) ``` Graph after `prepare_fx`: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%x,), kwargs = {}) %x_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=x_activation_post_process_0_equalization_process_0](args = (%x_activation_post_process_0,), kwargs = {}) %linear1 : [#users=1] = call_module[target=linear1](args = (%x_activation_post_process_0_equalization_process_0,), kwargs = {}) %linear1_activation_post_process_0 : [#users=1] = call_module[target=linear1_activation_post_process_0](args = (%linear1,), kwargs = {}) %linear1_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=linear1_activation_post_process_0_equalization_process_0](args = (%linear1_activation_post_process_0,), kwargs = {}) %linear2 : [#users=1] = call_module[target=linear2](args = (%linear1_activation_post_process_0_equalization_process_0,), kwargs = {}) %linear2_activation_post_process_0 : [#users=1] = call_module[target=linear2_activation_post_process_0](args = (%linear2,), kwargs = {}) return linear2_activation_post_process_0 ``` Graph after equaliation functions: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_activation_post_process_0_equalization_process_0_scale : [#users=1] = get_attr[target=x_activation_post_process_0_equalization_process_0_scale] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_activation_post_process_0_equalization_process_0_scale), kwargs = {}) %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%mul,), kwargs = {}) %linear1 : [#users=1] = call_module[target=linear1](args = (%x_activation_post_process_0,), kwargs = {}) %linear1_activation_post_process_0 : [#users=1] = call_module[target=linear1_activation_post_process_0](args = (%linear1,), kwargs = {}) %linear2 : [#users=1] = call_module[target=linear2](args = (%linear1_activation_post_process_0,), kwargs = {}) %linear2_activation_post_process_0 : [#users=1] = call_module[target=linear2_activation_post_process_0](args = (%linear2,), kwargs = {}) return linear2_activation_post_process_0 ``` Graph after `convert_fx`: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_activation_post_process_0_equalization_process_0_scale : [#users=1] = get_attr[target=x_activation_post_process_0_equalization_process_0_scale] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_activation_post_process_0_equalization_process_0_scale), kwargs = {}) %linear1_input_scale_0 : [#users=1] = get_attr[target=linear1_input_scale_0] %linear1_input_zero_point_0 : [#users=1] = get_attr[target=linear1_input_zero_point_0] %quantize_per_tensor : [#users=1] = call_function[target=torch.quantize_per_tensor](args = (%mul, %linear1_input_scale_0, %linear1_input_zero_point_0, torch.quint8), kwargs = {}) %linear1 : [#users=1] = call_module[target=linear1](args = (%quantize_per_tensor,), kwargs = {}) %linear2 : [#users=1] = call_module[target=linear2](args = (%linear1,), kwargs = {}) %dequantize : [#users=1] = call_method[target=dequantize](args = (%linear2,), kwargs = {}) return dequantize ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D29204347 fbshipit-source-id: 6bb9e25e2468f50df523885ded2edc731f002ac1
Author
Parents
Loading