[quant] Input-Weight Equalization - support for connected F.linear layers (#60272)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60272
Test Plan:
`python test/test_quantization.py TestEqualizeFx`
Original model:
```
FunctionalLinear2Module(
(linear1): Linear()
(linear2): Linear()
)
```
Graph after `prepare_fx`:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%x,), kwargs = {})
%x_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=x_activation_post_process_0_equalization_process_0](args = (%x_activation_post_process_0,), kwargs = {})
%linear1_w : [#users=1] = get_attr[target=linear1.w]
%linear1_w_activation_post_process_0 : [#users=1] = call_module[target=linear1_w_activation_post_process_0](args = (%linear1_w,), kwargs = {})
%linear1_w_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=linear1_w_activation_post_process_0_equalization_process_0](args = (%linear1_w_activation_post_process_0,), kwargs = {})
%linear1_b : [#users=1] = get_attr[target=linear1.b]
%linear : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%x_activation_post_process_0_equalization_process_0, %linear1_w_activation_post_process_0_equalization_process_0), kwargs = {bias: %linear1_b})
%linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {})
%linear_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0_equalization_process_0](args = (%linear_activation_post_process_0,), kwargs = {})
%linear2_w : [#users=1] = get_attr[target=linear2.w]
%linear2_w_activation_post_process_0 : [#users=1] = call_module[target=linear2_w_activation_post_process_0](args = (%linear2_w,), kwargs = {})
%linear2_w_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=linear2_w_activation_post_process_0_equalization_process_0](args = (%linear2_w_activation_post_process_0,), kwargs = {})
%linear2_b : [#users=1] = get_attr[target=linear2.b]
%linear_1 : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%linear_activation_post_process_0_equalization_process_0, %linear2_w_activation_post_process_0_equalization_process_0), kwargs = {bias: %linear2_b})
%linear_1_activation_post_process_0 : [#users=1] = call_module[target=linear_1_activation_post_process_0](args = (%linear_1,), kwargs = {})
return linear_1_activation_post_process_0
```
Graph after equalization steps:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0]
%mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {})
%x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%mul,), kwargs = {})
%linear1_w : [#users=1] = get_attr[target=linear1.w]
%linear1_w_activation_post_process_0 : [#users=1] = call_module[target=linear1_w_activation_post_process_0](args = (%linear1_w,), kwargs = {})
%linear1_b : [#users=1] = get_attr[target=linear1.b]
%linear : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%x_activation_post_process_0, %linear1_w_activation_post_process_0), kwargs = {bias: %linear1_b})
%linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {})
%linear2_w : [#users=1] = get_attr[target=linear2.w]
%linear2_w_activation_post_process_0 : [#users=1] = call_module[target=linear2_w_activation_post_process_0](args = (%linear2_w,), kwargs = {})
%linear2_b : [#users=1] = get_attr[target=linear2.b]
%linear_1 : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%linear_activation_post_process_0, %linear2_w_activation_post_process_0), kwargs = {bias: %linear2_b})
%linear_1_activation_post_process_0 : [#users=1] = call_module[target=linear_1_activation_post_process_0](args = (%linear_1,), kwargs = {})
return linear_1_activation_post_process_0
```
Graph after `convert_fx`:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0]
%mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {})
%linear1_input_scale_0 : [#users=1] = get_attr[target=linear1_input_scale_0]
%linear1_input_zero_point_0 : [#users=1] = get_attr[target=linear1_input_zero_point_0]
%quantize_per_tensor : [#users=1] = call_function[target=torch.quantize_per_tensor](args = (%mul, %linear1_input_scale_0, %linear1_input_zero_point_0, torch.quint8), kwargs = {})
%linear1_packed_weight_0 : [#users=1] = get_attr[target=linear1_packed_weight_0]
%linear1_scale_0 : [#users=1] = get_attr[target=linear1_scale_0]
%linear1_zero_point_0 : [#users=1] = get_attr[target=linear1_zero_point_0]
%linear : [#users=1] = call_function[target=torch.ops.quantized.linear](args = (%quantize_per_tensor, %linear1_packed_weight_0, %linear1_scale_0, %linear1_zero_point_0), kwargs = {})
%linear2_packed_weight_0 : [#users=1] = get_attr[target=linear2_packed_weight_0]
%linear2_scale_0 : [#users=1] = get_attr[target=linear2_scale_0]
%linear2_zero_point_0 : [#users=1] = get_attr[target=linear2_zero_point_0]
%linear_1 : [#users=1] = call_function[target=torch.ops.quantized.linear](args = (%linear, %linear2_packed_weight_0, %linear2_scale_0, %linear2_zero_point_0), kwargs = {})
%dequantize : [#users=1] = call_method[target=dequantize](args = (%linear_1,), kwargs = {})
return dequantize
```
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D29267218
fbshipit-source-id: 6b97bed1a307f1d0b1f5efcbecf41f35418242f7