[quant] Input-Weight Equalization - support for connected linear layers (#60034)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60034
Added support for equalizing models with connected linear
layers. To account for connected linear layers, we will additionally
multiply the previous weight values (row-wise) by the next equalization
scale, and remove the input equalization observer between the two linear
layers. We also want to scale the bias by the next equalization scale.
The math is shown here: https://fb.quip.com/fK8rA9aRM4ca .
Original Model: `x -> linear1 -> linear2`
After `prepare_fx`: `x -> InpEqObs -> InpQuantObs -> linear1 ->
OutQuantObs -> InpEqObs -> linear2`
After equalization: `x -> mul -> InpQuantObs -> linear1 -> OutQuantObs
-> linear2`
Test Plan:
`python test/test_quantization.py
TestEqualizeFx.test_input_weight_equalization_convert`
Original Model:
```
Linear2Module(
(linear1): Linear(in_features=2, out_features=2, bias=True)
(linear2): Linear(in_features=2, out_features=2, bias=True)
)
```
Graph after `prepare_fx`:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%x,), kwargs = {})
%x_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=x_activation_post_process_0_equalization_process_0](args = (%x_activation_post_process_0,), kwargs = {})
%linear1 : [#users=1] = call_module[target=linear1](args = (%x_activation_post_process_0_equalization_process_0,), kwargs = {})
%linear1_activation_post_process_0 : [#users=1] = call_module[target=linear1_activation_post_process_0](args = (%linear1,), kwargs = {})
%linear1_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=linear1_activation_post_process_0_equalization_process_0](args = (%linear1_activation_post_process_0,), kwargs = {})
%linear2 : [#users=1] = call_module[target=linear2](args = (%linear1_activation_post_process_0_equalization_process_0,), kwargs = {})
%linear2_activation_post_process_0 : [#users=1] = call_module[target=linear2_activation_post_process_0](args = (%linear2,), kwargs = {})
return linear2_activation_post_process_0
```
Graph after equaliation functions:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_activation_post_process_0_equalization_process_0_scale : [#users=1] = get_attr[target=x_activation_post_process_0_equalization_process_0_scale]
%mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_activation_post_process_0_equalization_process_0_scale), kwargs = {})
%x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%mul,), kwargs = {})
%linear1 : [#users=1] = call_module[target=linear1](args = (%x_activation_post_process_0,), kwargs = {})
%linear1_activation_post_process_0 : [#users=1] = call_module[target=linear1_activation_post_process_0](args = (%linear1,), kwargs = {})
%linear2 : [#users=1] = call_module[target=linear2](args = (%linear1_activation_post_process_0,), kwargs = {})
%linear2_activation_post_process_0 : [#users=1] = call_module[target=linear2_activation_post_process_0](args = (%linear2,), kwargs = {})
return linear2_activation_post_process_0
```
Graph after `convert_fx`:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_activation_post_process_0_equalization_process_0_scale : [#users=1] = get_attr[target=x_activation_post_process_0_equalization_process_0_scale]
%mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_activation_post_process_0_equalization_process_0_scale), kwargs = {})
%linear1_input_scale_0 : [#users=1] = get_attr[target=linear1_input_scale_0]
%linear1_input_zero_point_0 : [#users=1] = get_attr[target=linear1_input_zero_point_0]
%quantize_per_tensor : [#users=1] = call_function[target=torch.quantize_per_tensor](args = (%mul, %linear1_input_scale_0, %linear1_input_zero_point_0, torch.quint8), kwargs = {})
%linear1 : [#users=1] = call_module[target=linear1](args = (%quantize_per_tensor,), kwargs = {})
%linear2 : [#users=1] = call_module[target=linear2](args = (%linear1,), kwargs = {})
%dequantize : [#users=1] = call_method[target=dequantize](args = (%linear2,), kwargs = {})
return dequantize
```
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D29204347
fbshipit-source-id: 6bb9e25e2468f50df523885ded2edc731f002ac1