pytorch
782f3489 - [Quant][fx][bc-breaking] Integrate BackendConfig with quantization flow (part 2) (#82557)

Commit
3 years ago
[Quant][fx][bc-breaking] Integrate BackendConfig with quantization flow (part 2) (#82557) This is part 2 of the effort to replace `backend_config_dict` with a python config object, a more formal and robust API that leads to better user experience. This commit integrates the `BackendConfig` implemented in part 1 (https://github.com/pytorch/pytorch/pull/81469) with the existing FX graph mode quantization flow. Test Plan: python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps BC-breaking Notes: Before: ``` import torch from torch.ao.quantization import get_default_qconfig_mapping from torch.ao.quantization.backend_config import ObservationType from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx dtype_config = { "input_dtype": torch.quint8, "output_dtype": torch.quint8 "weight_dtype": torch.qint8, "bias_dtype": torch.float, } backend_config_dict = { "name": "my_backend", "configs": [{ "pattern": torch.nn.Linear, "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, "dtype_configs": [dtype_config], "root_module": torch.nn.Linear, "reference_quantized_module": torch.nn.quantized._reference.Linear, "qat_module": torch.nn.qat.Linear, }] } m = MyModel() qconfig_mapping = get_default_qconfig_mapping() example_inputs = (torch.rand(3, 3),) m = prepare_fx( m, qconfig_mapping, example_inputs, backend_config_dict=backend_config_dict) m = convert_fx(m, backend_config_dict=backend_config_dict) ``` After: ``` import torch from torch.ao.quantization import get_default_qconfig_mapping from torch.ao.quantization.backend_config import ( BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType, ) from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx dtype_config = DTypeConfig( input_dtype=torch.quint8, output_dtype=torch.quint8 weight_dtype=torch.qint8, bias_dtype=torch.float, ) backend_config = BackendConfig("my_backend").set_backend_pattern_config( BackendPatternConfig(torch.nn.Linear) .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .add_dtype_config(dtype_config) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(torch.nn.quantized._reference.Linear) .set_qat_module(torch.nn.qat.Linear)) m = MyModel() qconfig_mapping = get_default_qconfig_mapping() example_inputs = (torch.rand(3, 3),) m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) m = convert_fx(m, backend_config=backend_config) ``` Reviewers: jerryzh168 Subscribers: jerryzh168, supriyar Differential Revision: [D38471932](https://our.internmc.facebook.com/intern/diff/D38471932) Pull Request resolved: https://github.com/pytorch/pytorch/pull/82557 Approved by: https://github.com/jerryzh168
Author
Committer
Parents
Loading