pytorch
e8f1f4ed - [quant][pt2][ROCm] follow-up PR 109908 for miopen_batch_norm (#110653)

Commit
1 year ago
[quant][pt2][ROCm] follow-up PR 109908 for miopen_batch_norm (#110653) Fixes recent broken unit tests caused by PR #109908 because cudnn and miopen have separate batch norm functions. ``` 2023-10-05T09:35:01.6606614Z _______________ TestQuantizePT2EQAT.test_qat_conv_bn_fusion_cuda _______________ 2023-10-05T09:35:01.6606948Z Traceback (most recent call last): 2023-10-05T09:35:01.6607362Z File "/var/lib/jenkins/pytorch/test/quantization/pt2e/test_quantize_pt2e_qat.py", line 323, in test_qat_conv_bn_fusion_cuda 2023-10-05T09:35:01.6607767Z self._verify_symmetric_xnnpack_qat_graph( 2023-10-05T09:35:01.6608217Z File "/var/lib/jenkins/pytorch/test/quantization/pt2e/test_quantize_pt2e_qat.py", line 130, in _verify_symmetric_xnnpack_qat_graph 2023-10-05T09:35:01.6608658Z self._verify_symmetric_xnnpack_qat_graph_helper( 2023-10-05T09:35:01.6609105Z File "/var/lib/jenkins/pytorch/test/quantization/pt2e/test_quantize_pt2e_qat.py", line 173, in _verify_symmetric_xnnpack_qat_graph_helper 2023-10-05T09:35:01.6609623Z m = prepare_qat_pt2e(m, quantizer) 2023-10-05T09:35:01.6610171Z File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/ao/quantization/quantize_pt2e.py", line 178, in prepare_qat_pt2e 2023-10-05T09:35:01.6610561Z _fuse_conv_bn_qat(model) 2023-10-05T09:35:01.6611072Z File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/ao/quantization/pt2e/qat_utils.py", line 501, in _fuse_conv_bn_qat 2023-10-05T09:35:01.6611497Z m = _fuse_conv_bn_qat_helper(m, is_cuda=True) 2023-10-05T09:35:01.6612065Z File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/ao/quantization/pt2e/qat_utils.py", line 575, in _fuse_conv_bn_qat_helper 2023-10-05T09:35:01.6612492Z _get_conv_bn_getitem_nodes(r.replacements) 2023-10-05T09:35:01.6613058Z File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/ao/quantization/pt2e/qat_utils.py", line 383, in _get_conv_bn_getitem_nodes 2023-10-05T09:35:01.6613465Z assert bn_node is not None 2023-10-05T09:35:01.6613716Z AssertionError ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/110653 Approved by: https://github.com/jerryzh168, https://github.com/pruthvistony
Author
Committer
Parents
Loading