[quant] Input-Weight Equalization - support for LinearReLU layers (#60653)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60653
Special casing was needed to retrieve the weight attribute from the linear submodule of fused LinearReLU layers.
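As a rough illustration of the special case (not the exact code in this PR), a fused LinearReLU stores the Linear as its first submodule, so the weight sits one attribute access deeper than on a plain Linear:

```python
import torch.nn as nn
import torch.nn.intrinsic as nni

plain = nn.Linear(5, 5)
fused = nni.LinearReLU(nn.Linear(5, 5), nn.ReLU())

# Plain linear: the weight is a direct attribute.
w_plain = plain.weight

# Fused LinearReLU: the fused container holds the Linear as its first
# submodule, so the weight has to be fetched through it.
w_fused = fused[0].weight
```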
Initial Model: `x -> linear1 -> relu`
After fusion: `x -> linearRelu`
After prepare: `x -> input_quant_obs -> input_eq_obs1 -> linearRelu -> output_quant_obs1`
After equalization functions: `x -> mul -> input_quant_obs (scaled) -> linearRelu -> output_quant_obs1`
After convert: `x -> mul -> quantize_per_tensor -> quantized::linearRelu -> dequantize`
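The inserted `mul` node is numerically safe because the equalization scale applied to the input is divided out of the weight's input channels. A minimal sketch of that identity, with made-up shapes and a hypothetical positive scale `s`:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 5)    # batch of inputs
W = torch.randn(5, 5)    # linear weight (out_features x in_features)
b = torch.randn(5)       # linear bias
s = torch.rand(5) + 0.5  # hypothetical positive per-channel equalization scale

# Reference: plain linear + relu.
ref = torch.relu(x @ W.t() + b)

# Equalized: multiply the input by s and divide the weight's input
# channels by s -- the scales cancel, so the float output is unchanged.
out = torch.relu((x * s) @ (W / s).t() + b)
```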
More step-throughs here: https://fb.quip.com/A9J3AsBxkykR
Test Plan:
`python test/test_quantization.py TestEqualizeFx`
Original model:
```
LinearReluModel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
  (relu): ReLU()
)
```
Graph after `prepare_fx`:
```
graph():
    %x : [#users=1] = placeholder[target=x]
    %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%x,), kwargs = {})
    %x_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=x_activation_post_process_0_equalization_process_0](args = (%x_activation_post_process_0,), kwargs = {})
    %fc : [#users=1] = call_module[target=fc](args = (%x_activation_post_process_0_equalization_process_0,), kwargs = {})
    %fc_activation_post_process_0 : [#users=1] = call_module[target=fc_activation_post_process_0](args = (%fc,), kwargs = {})
    return fc_activation_post_process_0
```
Graph after equalization functions:
```
graph():
    %x : [#users=1] = placeholder[target=x]
    %x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0]
    %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {})
    %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_0](args = (%mul,), kwargs = {})
    %fc : [#users=1] = call_module[target=fc](args = (%x_activation_post_process_0,), kwargs = {})
    %fc_activation_post_process_0 : [#users=1] = call_module[target=fc_activation_post_process_0](args = (%fc,), kwargs = {})
    return fc_activation_post_process_0
```
Graph after `convert_fx`:
```
graph():
    %x : [#users=1] = placeholder[target=x]
    %x_equalization_scale0 : [#users=1] = get_attr[target=x_equalization_scale0]
    %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_scale0), kwargs = {})
    %fc_input_scale_0 : [#users=1] = get_attr[target=fc_input_scale_0]
    %fc_input_zero_point_0 : [#users=1] = get_attr[target=fc_input_zero_point_0]
    %quantize_per_tensor : [#users=1] = call_function[target=torch.quantize_per_tensor](args = (%mul, %fc_input_scale_0, %fc_input_zero_point_0, torch.quint8), kwargs = {})
    %fc : [#users=1] = call_module[target=fc](args = (%quantize_per_tensor,), kwargs = {})
    %dequantize : [#users=1] = call_method[target=dequantize](args = (%fc,), kwargs = {})
    return dequantize
```
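Outside of FX, the node sequence in the converted graph corresponds to the eager ops below; the scale and zero point are placeholders standing in for the `fc_input_scale_0` / `fc_input_zero_point_0` attributes, and the equalization scale is made up:

```python
import torch

x = torch.randn(4, 5)
eq_scale = torch.rand(5) + 0.5          # stand-in for x_equalization_scale0

x_eq = torch.mul(x, eq_scale)           # the mul node
xq = torch.quantize_per_tensor(         # the quantize_per_tensor node
    x_eq, scale=0.05, zero_point=128, dtype=torch.quint8)
# ... the quantized::linearRelu module would consume xq here ...
x_dq = xq.dequantize()                  # the dequantize node
```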
Imported from OSS
Reviewed By: supriyar
Differential Revision: D29406999
fbshipit-source-id: add38e8e7fb84a241c3b10bfb8451b50103effd4