pytorch
b86554ab - [quant][fx] Fix dynamic weighted op lowering when input is used multiple times (#74364)

Commit
2 years ago
[quant][fx] Fix dynamic weighted op lowering when input is used multiple times (#74364) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74364 if a input is used multiple times in modules that are dynamically quantized: ``` x -- linear1 \-- linear2 ``` we'll insert quantize_per_tensor_dynamic and dequantize for input, and we'll have a duplicate pass to duplicate dequantize ops for pattern matching: ``` x - quantize_per_tensor_dynamic - dequantize1 - linear1 \----- dequantize2 - linear2 ``` But we also have a check in the lowering code that if quantize_per_tensor_dynamic is used by multiple nodes we'll skip the pattern, so the pattern is not recognized, we need to duplicate quantize_per_tensor_dynamic as well in this case to recover both patterns: ``` x - quantize_per_tensor_dynamic1 -- dequantize1 -- linear1 \- quantize_per-tensor_dynamic2 -- dequantize2 -- linear2 ``` so that they can be fused into dynamic linear: ``` x - linear_dynamic1 \-- linear_dynamic2 ``` Test Plan: python test/test_quantization.py TestQuantizeFx.test_dynamic_linear_input_multiple_use Imported from OSS Reviewed By: yixin94 Differential Revision: D34952755 fbshipit-source-id: a950159fd6a661e84faf0baf1692f6783904cfb3 (cherry picked from commit 8a6896801fdd96a55476faca4ccb7ba0b0bdb058)
Author
Committer
Parents
Loading