[Quant][fx][bc-breaking] Integrate BackendConfig with quantization flow (part 2) (#82557)
This is part 2 of the effort to replace `backend_config_dict` with
a Python config object, a more formal and robust API that leads to
a better user experience. This commit integrates the `BackendConfig`
implemented in part 1 (https://github.com/pytorch/pytorch/pull/81469)
with the existing FX graph mode quantization flow.
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
BC-breaking Notes:
Before:
```
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import ObservationType
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
dtype_config = {
    "input_dtype": torch.quint8,
    "output_dtype": torch.quint8,
    "weight_dtype": torch.qint8,
    "bias_dtype": torch.float,
}
backend_config_dict = {
    "name": "my_backend",
    "configs": [{
        "pattern": torch.nn.Linear,
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [dtype_config],
        "root_module": torch.nn.Linear,
        "reference_quantized_module": torch.nn.quantized._reference.Linear,
        "qat_module": torch.nn.qat.Linear,
    }]
}
m = MyModel()
qconfig_mapping = get_default_qconfig_mapping()
example_inputs = (torch.rand(3, 3),)
m = prepare_fx(
    m, qconfig_mapping, example_inputs,
    backend_config_dict=backend_config_dict)
m = convert_fx(m, backend_config_dict=backend_config_dict)
```
After:
```
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)
backend_config = BackendConfig("my_backend").set_backend_pattern_config(
    BackendPatternConfig(torch.nn.Linear)
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        .add_dtype_config(dtype_config)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(torch.nn.quantized._reference.Linear)
        .set_qat_module(torch.nn.qat.Linear))
m = MyModel()
qconfig_mapping = get_default_qconfig_mapping()
example_inputs = (torch.rand(3, 3),)
m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
m = convert_fx(m, backend_config=backend_config)
```
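The new API relies on fluent setters: each `set_*`/`add_*` method returns the config object itself, so a full pattern config can be built in one chained expression. A minimal, dependency-free sketch of that pattern (the `Toy*` class and attribute names below are illustrative stand-ins, not the actual torch.ao.quantization classes):

```python
# Sketch of the fluent builder style used by BackendConfig /
# BackendPatternConfig. Each setter mutates the object and returns
# `self`, which is what makes method chaining possible.
class ToyPatternConfig:
    def __init__(self, pattern):
        self.pattern = pattern
        self.dtype_configs = []
        self.root_module = None

    def add_dtype_config(self, dtype_config):
        self.dtype_configs.append(dtype_config)
        return self  # returning self enables chaining

    def set_root_module(self, module_name):
        self.root_module = module_name
        return self


class ToyBackendConfig:
    def __init__(self, name):
        self.name = name
        self.configs = []

    def set_backend_pattern_config(self, config):
        self.configs.append(config)
        return self


# Build a config in one chained expression, mirroring the "After" example.
cfg = ToyBackendConfig("my_backend").set_backend_pattern_config(
    ToyPatternConfig("Linear")
    .add_dtype_config({"input_dtype": "quint8", "output_dtype": "quint8"})
    .set_root_module("Linear"))

print(cfg.name, len(cfg.configs), cfg.configs[0].root_module)
```

Compared with the raw dict, this style lets the config classes validate each field as it is set and gives IDEs discoverable method names, which is the main usability gain of the new API.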
Reviewers: jerryzh168
Subscribers: jerryzh168, supriyar
Differential Revision: [D38471932](https://our.internmc.facebook.com/intern/diff/D38471932)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82557
Approved by: https://github.com/jerryzh168