pytorch
691a44f4 - [Quant][fx][bc-breaking] Add simpler BackendConfig pattern format (#90698)

Summary: The existing BackendConfig fusion pattern uses a "reversed nested tuple" format that is highly unintuitive. For example:

```
linear-relu -> (nn.ReLU, nn.Linear)
conv-bn-relu -> (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
```

This pattern format also complicates the signatures of the user-specified "fuser methods", which needed to accept their arguments in the same reversed nested order to match the patterns:

```
def fuse_linear_relu(is_qat, relu, linear):
    ...

def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    ...
```

Instead, this commit introduces a new pattern format that simply specifies the ops in forward order with no nesting:

```
linear-relu -> (nn.Linear, nn.ReLU)
conv-bn-relu -> (nn.Conv2d, nn.BatchNorm2d, nn.ReLU)

def fuse_linear_relu(is_qat, linear, relu):
    ...

def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    ...
```

Note that the legacy "reversed nested tuple" format is still used internally since it is more general. In the future, we should replace it with the format used by the subgraph rewriter in `torch.fx`, and simplify the existing pattern matching code to handle the new format added in this commit.

BC-breaking Notes:

Before:
```
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
from torch.ao.quantization.backend_config import BackendPatternConfig

def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)
```

After:
```
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig((nn.Conv2d, nn.BatchNorm2d, nn.ReLU)) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d)
```

OR (for backward compatibility):
```
def fuse_conv_bn_relu(is_qat, relu, bn_conv):
    (bn, conv) = bn_conv
    return nni.ConvBnReLU2d(conv, bn, relu)

config = BackendPatternConfig() \
    ._set_pattern_complex_format((nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))) \
    .set_dtype_configs(...) \
    .set_fuser_method(fuse_conv_bn_relu) \
    .set_fused_module(nni.ConvBnReLU2d) \
    ._set_use_legacy_pattern_format(True)
```

Before:
```
backend_config.configs  # returns Dict[Pattern, BackendPatternConfig]
```

After:
```
backend_config.configs  # returns List[BackendPatternConfig]
```

Test Plan:

python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestBackendConfig

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

Differential Revision: [D41954553](https://our.internmc.facebook.com/intern/diff/D41954553)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90698
Approved by: https://github.com/vkuzo, https://github.com/jerryzh168
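As an illustration of the second BC-breaking change (`backend_config.configs` now returning a list), downstream code that previously looked up configs by pattern can rebuild that mapping from the new list. This is a minimal sketch, not part of the commit; it assumes each config exposes its pattern via the public `pattern` attribute and was registered with the new-format pattern:

```
from torch.ao.quantization.backend_config import get_native_backend_config

backend_config = get_native_backend_config()

# Rebuild the old Dict[Pattern, BackendPatternConfig] view from the new
# List[BackendPatternConfig]. Configs registered via the legacy complex
# format may not populate `pattern`, so those are skipped here.
configs_by_pattern = {
    cfg.pattern: cfg
    for cfg in backend_config.configs
    if cfg.pattern is not None
}
```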