[bc-breaking][quant][graphmode][fx] Produce dequant - fp_op - quant pattern for copy nodes (#61763)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61763
This PR changes the is_reference=True option of convert_fx so that copy nodes, such as the maxpool op,
produce a dequant - fp_op - quant pattern.
Before the PR:
```
def forward(self, x):
    maxpool2d_input_scale_0 = self.maxpool2d_input_scale_0
    maxpool2d_input_zero_point_0 = self.maxpool2d_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, maxpool2d_input_scale_0, maxpool2d_input_zero_point_0, torch.quint8); x = maxpool2d_input_scale_0 = maxpool2d_input_zero_point_0 = None
    maxpool2d = self.maxpool2d(quantize_per_tensor); quantize_per_tensor = None
    dequantize = maxpool2d.dequantize(); maxpool2d = None
    return dequantize
```
After the PR (the maxpool2d that works with quantized input is expanded to a "dequant - maxpool2d - quant" pattern):
```
def forward(self, x):
    maxpool2d_input_scale_0 = self.maxpool2d_input_scale_0
    maxpool2d_input_zero_point_0 = self.maxpool2d_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, maxpool2d_input_scale_0, maxpool2d_input_zero_point_0, torch.quint8); x = maxpool2d_input_scale_0 = maxpool2d_input_zero_point_0 = None
    dequantize = quantize_per_tensor.dequantize(); quantize_per_tensor = None
    maxpool2d = self.maxpool2d(dequantize); dequantize = None
    maxpool2d_output_scale_0 = self.maxpool2d_output_scale_0
    maxpool2d_output_zero_point_0 = self.maxpool2d_output_zero_point_0
    quantize_per_tensor_1 = torch.quantize_per_tensor(maxpool2d, maxpool2d_output_scale_0, maxpool2d_output_zero_point_0, torch.quint8); maxpool2d = maxpool2d_output_scale_0 = maxpool2d_output_zero_point_0 = None
    dequantize_1 = quantize_per_tensor_1.dequantize(); quantize_per_tensor_1 = None
    return dequantize_1
```
Note that the call to self.maxpool2d is expanded to:
```
dequantize = quantize_per_tensor.dequantize(); quantize_per_tensor = None
maxpool2d = self.maxpool2d(dequantize); dequantize = None
maxpool2d_output_scale_0 = self.maxpool2d_output_scale_0
maxpool2d_output_zero_point_0 = self.maxpool2d_output_zero_point_0
quantize_per_tensor_1 = torch.quantize_per_tensor(maxpool2d, maxpool2d_output_scale_0, maxpool2d_output_zero_point_0, torch.quint8); maxpool2d = maxpool2d_output_scale_0 = maxpool2d_output_zero_point_0 = None
```
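As a rough illustration of why this expansion should preserve results for a copy node, here is a minimal pure-Python sketch (not the FX implementation). It assumes the input and output share the same scale and zero point, which the test name suggests but this message does not state outright. The helpers `quantize`, `dequantize`, and `maxpool1d` are hypothetical stand-ins for the affine quint8 mapping and a non-overlapping max pool.
```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    # Affine quantization to a quint8-like range (hypothetical helper).
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    # Inverse affine mapping back to floating point.
    return (q - zero_point) * scale


def maxpool1d(values, kernel_size):
    # Non-overlapping 1-D max pool over a plain list.
    stop = len(values) - kernel_size + 1
    return [max(values[i:i + kernel_size]) for i in range(0, stop, kernel_size)]


# Assumption: input and output use the same quantization parameters.
scale, zero_point = 0.1, 128
x = [-1.23, 0.4, 2.1, -0.7, 0.0, 3.3]
q_in = [quantize(v, scale, zero_point) for v in x]

# "Before" style: max pool applied directly to the quantized integers.
q_direct = maxpool1d(q_in, 2)

# "After" style: dequant - fp maxpool - quant.
fp = [dequantize(q, scale, zero_point) for q in q_in]
q_reference = [quantize(v, scale, zero_point) for v in maxpool1d(fp, 2)]

# Max pool only selects elements, so both paths pick the same integers.
assert q_direct == q_reference
```
Because max pooling only selects existing elements and dequantization is monotonic for a positive scale, the two paths agree here; this would not necessarily hold if the output used different quantization parameters.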
Test Plan:
```
python test/test_quantization.py TestQuantizeFx.test_copy_node_has_shared_actpp_instance
```
Imported from OSS
Reviewed By: vkuzo
Differential Revision: D29728900
fbshipit-source-id: cf2c7f1f6659e3ba97cbb7c6204dd13983da10bd