pytorch
457a3fb6 - [bc-breaking][quant][graphmode][fx] Produce dequant - fp_op - quant pattern for copy nodes (#61763)

Commit
4 years ago
[bc-breaking][quant][graphmode][fx] Produce dequant - fp_op - quant pattern for copy nodes (#61763) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61763 This PR changes the is_reference=True option for convert_fx to produce a dequant - fp_op - quant pattern for copy nodes like maxpool op. Before the PR: ``` def forward(self, x): maxpool2d_input_scale_0 = self.maxpool2d_input_scale_0 maxpool2d_input_zero_point_0 = self.maxpool2d_input_zero_point_0 quantize_per_tensor = torch.quantize_per_tensor(x, maxpool2d_input_scale_0, maxpool2d_input_zero_point_0, torch.quint8); x = maxpool2d_input_scale_0 = maxpool2d_input_zero_point_0 = None maxpool2d = self.maxpool2d(quantize_per_tensor); quantize_per_tensor = None dequantize = maxpool2d.dequantize(); maxpool2d = None return dequantize ``` After (we expand the maxpool2d that works with quantized input to "dequant - maxpool2d - quant" pattern ``` def forward(self, x): maxpool2d_input_scale_0 = self.maxpool2d_input_scale_0 maxpool2d_input_zero_point_0 = self.maxpool2d_input_zero_point_0 quantize_per_tensor = torch.quantize_per_tensor(x, maxpool2d_input_scale_0, maxpool2d_input_zero_point_0, torch.quint8); x = maxpool2d_input_scale_0 = maxpool2d_input_zero_point_0 = None dequantize = quantize_per_tensor.dequantize(); quantize_per_tensor = None maxpool2d = self.maxpool2d(dequantize); dequantize = None maxpool2d_output_scale_0 = self.maxpool2d_output_scale_0 maxpool2d_output_zero_point_0 = self.maxpool2d_output_zero_point_0 quantize_per_tensor_1 = torch.quantize_per_tensor(maxpool2d, maxpool2d_output_scale_0, maxpool2d_output_zero_point_0, torch.quint8); maxpool2d = maxpool2d_output_scale_0 = maxpool2d_output_zero_point_0 = None dequantize_1 = quantize_per_tensor_1.dequantize(); quantize_per_tensor_1 = None return dequantize_1 ``` note that the call to self.maxpool2d is expanded to ``` dequantize = quantize_per_tensor.dequantize(); quantize_per_tensor = None maxpool2d = self.maxpool2d(dequantize); dequantize = None maxpool2d_output_scale_0 = self.maxpool2d_output_scale_0 maxpool2d_output_zero_point_0 = self.maxpool2d_output_zero_point_0 quantize_per_tensor_1 = torch.quantize_per_tensor(maxpool2d, maxpool2d_output_scale_0, maxpool2d_output_zero_point_0, torch.quint8); maxpool2d = maxpool2d_output_scale_0 = maxpool2d_output_zero_point_0 = None ``` Test Plan: ``` python test/test_quantization.py TestQuantizeFx.test_copy_node_has_shared_actpp_instance ``` Imported from OSS Reviewed By: vkuzo Differential Revision: D29728900 fbshipit-source-id: cf2c7f1f6659e3ba97cbb7c6204dd13983da10bd
Author
Parents
Loading