ns for fx: expose function to add comparisons between logged values (#60311)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60311
Adds a user facing utility function to FX Numeric Suite Core APIs
for comparing the values extracted by the loggers to each other.
This is needed for any kind of analysis, so would be great to
provide an example implementation.
Example:
```
// code
m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
qconfig_dict = {'': torch.quantization.default_qconfig}
mp = torch.quantization.quantize_fx.prepare_fx(m, qconfig_dict)
mq = torch.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
results = extract_weights('fp32', mp, 'int8', mq)
extend_logger_results_with_comparison(
results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32')
print(results)
// results
{
'_1': {'weight': {
'fp32': [
{'type': 'weight', 'values': [tensor([[[[-0.3284]]]])], 'prev_node_name': '_1', 'prev_node_target_type': "<class 'torch.nn.modules.conv.Conv2d'>", 'ref_node_name': '_1', 'index_within_arg': 0, 'index_of_arg': 0}
],
'int8': [
{'type': 'weight', 'values': [tensor([[[[-0.3297]]]], size=(1, 1, 1, 1), dtype=torch.qint8,
quantization_scheme=torch.per_tensor_affine, scale=0.002575645223259926,
zero_point=0)], 'prev_node_name': '_1', 'prev_node_target_type': "<class 'torch.nn.quantized.modules.conv.Conv2d'>", 'ref_node_name': '_1', 'index_within_arg': 0, 'index_of_arg': 0, 'sqnr_int8_vs_fp32': [tensor(48.1308)]}
]
}},
'_0': {'weight': {
'fp32': [{'type': 'weight', 'values': [tensor([[[[0.5205]]]])], 'prev_node_name': '_0', 'prev_node_target_type': "<class 'torch.nn.modules.conv.Conv2d'>", 'ref_node_name': '_0', 'index_within_arg': 0, 'index_of_arg': 0}],
'int8': [{'type': 'weight', 'values': [tensor([[[[0.5184]]]], size=(1, 1, 1, 1), dtype=torch.qint8,
quantization_scheme=torch.per_tensor_affine, scale=0.004082232713699341,
zero_point=0)], 'prev_node_name': '_0', 'prev_node_target_type': "<class 'torch.nn.quantized.modules.conv.Conv2d'>", 'ref_node_name': '_0', 'index_within_arg': 0, 'index_of_arg': 0, 'sqnr_int8_vs_fp32': [tensor(48.1309)]}]
}}
}
```
Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_extend_logger_results_with_comparison
```
Imported from OSS
Reviewed By: hx89
Differential Revision: D29244715
fbshipit-source-id: a5547b449ea54e046c752119559be49bd738beea