[quant] Input Weight Equalization - prepare modifications (#59747)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59747
Modifies prepare_fx for input-weight equalization. If a current
node is being equalized (there exists a EqualizationQConfig), then the
EqualizationObserver will be inserted before its quantization observer.
For a singular linear layer, the general flow looks like:
Original graph: `x0 -> linear -> x1`, `w -> linear`
After prepare: `x0 -> InpEqObs -> MinMaxObs -> linear1 -> MinMaxObs -> x1`
`w -> WeightEqObs -> MinMaxObs -> linear1`
For two connected linear layers, the general flow looks like:
Original graph: `x0 -> linear1 -> linear2 -> x1`,
`w1 -> linear1`, `w2 -> linear2`
After prepare: `x0 -> InpEqObs -> MinMaxObs -> linear1 -> MinMaxObs -> InpEqObs -> linear2 -> MinMaxObs -> x1`
`w1 -> WeightEqObs -> MinMaxObs -> linear1`, `w2 -> WeightEqObs -> MinMaxObs -> linear2
Test Plan:
`python test/test_quantization.py
TestEqualizeFx.test_input_equalization_prepare`
Original model with one `nn.Linear` layer
```
LinearModule(
(linear): Linear(in_features=1, out_features=1, bias=True)
)
```
Graph after `prepare_fx`:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_equalization_process_0 : [#users=1] = call_module[target=x_equalization_process_0](args = (%x,), kwargs = {})
%x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_00](args = (%x_equalization_process_0,), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%x_activation_post_process_0,), kwargs = {})
%linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {})
return linear_activation_post_process_0
```
--------------------------------------
Original model with two connected functional linear layers
```
FunctionalLinearModule(
(linear1): Linear()
(linear2): Linear()
)
```
Graph after `prepare_fx`:
```
graph():
%x : [#users=1] = placeholder[target=x]
%x_equalization_process_0 : [#users=1] = call_module[target=x_equalization_process_0](args = (%x,), kwargs = {})
%x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_00](args = (%x_equalization_process_0,), kwargs = {})
%linear1_w : [#users=1] = get_attr[target=linear1.w]
%linear1_w_equalization_process_0 : [#users=1] = call_module[target=linear1_w_equalization_process_0](args = (%linear1_w,), kwargs = {})
%linear1_w_activation_post_process_0 : [#users=1] = call_module[target=linear1_w_activation_post_process_00](args = (%linear1_w_equalization_process_0,), kwargs = {})
%linear1_b : [#users=1] = get_attr[target=linear1.b]
%linear : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%x_activation_post_process_0, %linear1_w_activation_post_process_0), kwargs = {bias: %linear1_b})
%linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {})
%linear_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0_equalization_process_0](args = (%linear_activation_post_process_0,), kwargs = {})
%linear2_w : [#users=1] = get_attr[target=linear2.w]
%linear2_w_equalization_process_0 : [#users=1] = call_module[target=linear2_w_equalization_process_0](args = (%linear2_w,), kwargs = {})
%linear2_w_activation_post_process_0 : [#users=1] = call_module[target=linear2_w_activation_post_process_00](args = (%linear2_w_equalization_process_0,), kwargs = {})
%linear2_b : [#users=1] = get_attr[target=linear2.b]
%linear_1 : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%linear_activation_post_process_0_equalization_process_0, %linear2_w_activation_post_process_0), kwargs = {bias: %linear2_b})
%linear_1_activation_post_process_0 : [#users=1] = call_module[target=linear_1_activation_post_process_0](args = (%linear_1,), kwargs = {})
return linear_1_activation_post_process_0
```
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D29135316
fbshipit-source-id: 91697e805ede254dbb2a42ee4c23eb1c1c64590e