pytorch
3de79b77 - [quant] Input-Weight Equaliaztion - convert modifications (#59963)

Commit
3 years ago
[quant] Input-Weight Equaliaztion - convert modifications (#59963) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59963 When converting, before quantizing the nodes, we call `update_obs_for_equalization()` and `convert_eq_obs()`. `update_obs_for_equalization`: 1. For each InputEqualizationObserver, we find the corresponding WeightEqualizationObserver. 2. For nn.Linear layers, we will create an instance of the WeightEqualizationObserver, run forward on the observer with the given weights. 3. Calculate the equalization scale between the InputEqualizationObserver and WeightEqualizationObserver. `convert_eq_obs`: For every InputEqualizationObserver, we will do the following: 1. Create a node (ex. `x0_activation_post_process_scale`) containing the equalization scale constant. 2. Create another node containing a `mul` operator multiplying the equalization scale and the input. 3. Remove the current InputEqualizationObserver node, and replace it with the `mul` node. For every WeightEqualizationObserver, we will do the following: 1. Get the next equalization scale (we may need this for equalizing connected linear layers). 2. Scale the weights by multiplying it with the reciprocal of the current equalization scale and the next equalization scale Currently, this supports models with `nn.Linear` layers, but does not support connecting linear layers. Test Plan: `python test/test_quantization.py TestEqualizeFx.test_input_weight_equalization_convert` Original Model: ``` .LinearModule( (linear): Linear(in_features=2, out_features=2, bias=True) ) ``` Graph after `prepare_fx`: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_equalization_process_0 : [#users=1] = call_module[target=x_equalization_process_0](args = (%x,), kwargs = {}) %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_00](args = (%x_equalization_process_0,), kwargs = {}) %linear : [#users=1] = call_module[target=linear](args = (%x_activation_post_process_0,), kwargs = {}) %linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {}) return linear_activation_post_process_0 ``` Graph after equalization functions: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_equalization_process_0_scale : [#users=1] = get_attr[target=x_equalization_process_0_scale] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_process_0_scale), kwargs = {}) %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_00](args = (%mul,), kwargs = {}) %linear : [#users=1] = call_module[target=linear](args = (%x_activation_post_process_0,), kwargs = {}) %linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {}) return linear_activation_post_process_0 ``` Graph after `convert_fx`: ``` graph(): %x : [#users=1] = placeholder[target=x] %x_equalization_process_0_scale : [#users=1] = get_attr[target=x_equalization_process_0_scale] %mul : [#users=1] = call_function[target=torch.mul](args = (%x, %x_equalization_process_0_scale), kwargs = {}) %linear_input_scale_0 : [#users=1] = get_attr[target=linear_input_scale_0] %linear_input_zero_point_0 : [#users=1] = get_attr[target=linear_input_zero_point_0] %quantize_per_tensor : [#users=1] = call_function[target=torch.quantize_per_tensor](args = (%mul, %linear_input_scale_0, %linear_input_zero_point_0, torch.quint8), kwargs = {}) %linear : [#users=1] = call_module[target=linear](args = (%quantize_per_tensor,), kwargs = {}) %dequantize : [#users=1] = call_method[target=dequantize](args = (%linear,), kwargs = {}) return dequantize ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D29135358 fbshipit-source-id: 2d00056729041318463de61841483490b6bfeee5
Author
Parents
Loading