[quant][fx] Remove extra q-dq for weight bias in normalization ops (#59882)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59882
Currently, for normalization ops, the weight and bias arguments are treated as activation inputs that require observers.
This results in extra quant-dequant ops being inserted for the weight and bias inputs.
This PR adds support for skipping observation of the weight/bias inputs of normalization operators, removing the redundant q-dq ops.
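For context, here is a minimal sketch of the FX graph-mode quantization flow that produces traces like the ones below. The module, shapes, and calibration input are hypothetical, chosen to match the printed graphs; the API shown is as of the PyTorch release this PR targets, where `prepare_fx` did not yet require an `example_inputs` argument and the entry points lived under `torch.quantization` rather than `torch.ao.quantization`:
```
import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Affine parameters passed to F.layer_norm as weight/bias
        self.scale = torch.nn.Parameter(torch.ones(2, 5, 5))
        self.bias = torch.nn.Parameter(torch.zeros(2, 5, 5))

    def forward(self, x):
        return F.layer_norm(x, [2, 5, 5], weight=self.scale, bias=self.bias)

m = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(m, qconfig_dict)  # inserts observers
prepared(torch.randn(4, 2, 5, 5))       # calibrate
quantized = convert_fx(prepared)        # lower to quantized ops
print(quantized.code)                   # graphs like the ones below
```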
Quantized graph with `F.layer_norm`:
Before this PR:
```
def forward(self, x):
    _input_scale_0 = self._input_scale_0
    _input_zero_point_0 = self._input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, _input_scale_0, _input_zero_point_0, torch.quint8); x = _input_scale_0 = _input_zero_point_0 = None
    scale = self.scale
    _input_scale_1 = self._input_scale_1
    _input_zero_point_1 = self._input_zero_point_1
    quantize_per_tensor_1 = torch.quantize_per_tensor(scale, _input_scale_1, _input_zero_point_1, torch.quint8); scale = _input_scale_1 = _input_zero_point_1 = None
    bias = self.bias
    _input_scale_2 = self._input_scale_2
    _input_zero_point_2 = self._input_zero_point_2
    quantize_per_tensor_2 = torch.quantize_per_tensor(bias, _input_scale_2, _input_zero_point_2, torch.quint8); bias = _input_scale_2 = _input_zero_point_2 = None
    _scale_0 = self._scale_0
    _zero_point_0 = self._zero_point_0
    dequantize = quantize_per_tensor_1.dequantize(); quantize_per_tensor_1 = None
    dequantize_1 = quantize_per_tensor_2.dequantize(); quantize_per_tensor_2 = None
    layer_norm = torch.ops.quantized.layer_norm(quantize_per_tensor, [2, 5, 5], weight = dequantize, bias = dequantize_1, eps = 1e-05, output_scale = _scale_0, output_zero_point = _zero_point_0); quantize_per_tensor = dequantize = dequantize_1 = _scale_0 = _zero_point_0 = None
    dequantize_2 = layer_norm.dequantize(); layer_norm = None
    return dequantize_2
```
After this PR:
```
def forward(self, x):
    _input_scale_0 = self._input_scale_0
    _input_zero_point_0 = self._input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, _input_scale_0, _input_zero_point_0, torch.quint8); x = _input_scale_0 = _input_zero_point_0 = None
    scale = self.scale
    bias = self.bias
    _scale_0 = self._scale_0
    _zero_point_0 = self._zero_point_0
    layer_norm = torch.ops.quantized.layer_norm(quantize_per_tensor, [2, 5, 5], weight = scale, bias = bias, eps = 1e-05, output_scale = _scale_0, output_zero_point = _zero_point_0); quantize_per_tensor = scale = bias = _scale_0 = _zero_point_0 = None
    dequantize = layer_norm.dequantize(); layer_norm = None
    return dequantize
```
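The q-dq pair was redundant because `torch.ops.quantized.layer_norm` consumes fp32 `weight` and `bias` tensors directly, as the "After" graph shows. A minimal standalone sketch demonstrating this (the shapes and scale/zero-point values are illustrative, not from the PR):
```
import torch
import torch.nn.functional as F

x = torch.randn(4, 2, 5, 5)
weight = torch.randn(2, 5, 5)  # fp32 affine parameters; no observer needed
bias = torch.randn(2, 5, 5)

# Quantize only the activation input, as in the "After" graph above.
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# The quantized op takes the fp32 weight/bias as-is.
qy = torch.ops.quantized.layer_norm(
    qx, [2, 5, 5], weight=weight, bias=bias, eps=1e-05,
    output_scale=0.1, output_zero_point=128)

# Compare against the fp32 reference; only quantization error remains.
ref = F.layer_norm(x, [2, 5, 5], weight=weight, bias=bias, eps=1e-05)
print((qy.dequantize() - ref).abs().max())
```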
Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_norm_weight_bias
Imported from OSS
Reviewed By: HDCharles, ailzhang
Differential Revision: D29068203
fbshipit-source-id: 24b5c38bbea5fd355d34522bfa654c9db18607da