pytorch
a029422c - [quant][graphmode][fx][refactor] Change the env map to add dtype as a key (#60054)

Commit
4 years ago
[quant][graphmode][fx][refactor] Change the env map to add dtype as a key (#60054) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60054 Previously env in convert is Dict[str, Tuple[Node, torch.dtype]], that is, at a given time each node can only have one dtype, this causes a problem for the following case: ``` class M(torch.nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 1, 1) def forward(self, x): x = self.conv(x) x1 = x.expand_as(x) x2 = torch.add(x, x1) return x2 def forward(self, x): x = self.activation_post_process_0(x) x = self.conv(x) x = self.activation_post_process_1(x) x1 = x.expand_as(x) x1 = self.activation_post_process_2(x1) x2 = torch.add(x, x1) x2 = self.activation_post_process_3(x2) return x2 def forward(self, x): x = torch.quantize_per_tensor(x, ...) x = self.conv(x). # quantized conv x = torch.dequantize(x) x1 = x.expand_as(x) x1 = torch.quantize_per_tensor(x1, ...) # Error: x is dequantized x2 = torch.ops.quantized.add(x, x1) return x2 Currently we have a env that is a map from node name of the observed graph to the Node in the quantized graph, here the problem is that following a quantized operator conv, we have two operators, one is expecting float input (expand_as), the other is expecting quantized input (quantized add), and in the quantized graph, ideally, expand_as should consume the dequantized output, and quantized add should consume the quantized output: quantized_conv - dequantize - expand_as \ ------- quantized_add But currently in env, each node needs to either be quantized or not quantized. Therefore we will need to change env to include dtype as well: env: Dict[str, Dict[dtype, Node]], e.g. {‘x’: {torch.float: dequantized_node, torch.quint8: quantized_node}} And when we load from the env, we will need to provide the dtype of the Node that we want to load as well. We can have a separate pass to figure out this information for each node. ``` Test Plan: python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps Imported from OSS Reviewed By: vkuzo Differential Revision: D29149408 fbshipit-source-id: c9e4b7d65444ab6a6f573929bae1db5037629892
Author
Parents
Loading