pytorch commit 62cad5b5: [quant][pt2] Support cudnn_batch_norm in QAT fusion (#109908)

Summary: Today we get different batch norm ops depending on the device the model is placed on at export time: exporting `model.cpu()` gives `_native_batch_norm_legit`, while exporting `model.cuda()` gives `cudnn_batch_norm`. QAT fusion currently only supports the former and silently ignores the latter. This commit fixes that by additionally matching on `cudnn_batch_norm` during QAT fusion.

Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT.test_qat_conv_bn_fusion
python test/test_quantization.py TestQuantizePT2EQAT.test_qat_conv_bn_relu_fusion

Reviewers: jerryzh168, kimishpatel
Subscribers: jerryzh168, kimishpatel, supriyar
Differential Revision: [D49615145](https://our.internmc.facebook.com/intern/diff/D49615145)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109908
Approved by: https://github.com/jerryzh168
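For illustration, the sketch below exports the same Conv + BN module on CPU and on CUDA and lists the batch norm ops that appear in each graph. This is a minimal approximation, not the PT2E QAT flow itself: the QAT tooling of this era captured graphs with `torch._export.capture_pre_autograd_graph`, whereas the sketch uses the public `torch.export.export` entry point, so the exact op variants you see may differ (an eval-mode module, for example, lowers to `_native_batch_norm_legit_no_training`).

```python
# Minimal sketch: show the device-dependent batch norm op in an exported
# graph. Uses torch.export.export rather than the QAT capture entry point,
# so treat the expected op names in the comments as assumptions.
import torch


class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


def batch_norm_targets(exported):
    # Collect the batch norm ops that actually appear in the exported graph.
    return sorted({
        str(node.target)
        for node in exported.graph.nodes
        if node.op == "call_function" and "batch_norm" in str(node.target)
    })


x = torch.randn(1, 3, 32, 32)

# CPU export in train mode: expect aten._native_batch_norm_legit.
ep_cpu = torch.export.export(ConvBN().train(), (x,))
print(batch_norm_targets(ep_cpu))

# CUDA export: expect aten.cudnn_batch_norm, the op this commit teaches
# the QAT fusion pattern matcher to recognize as well.
if torch.cuda.is_available():
    ep_cuda = torch.export.export(ConvBN().train().cuda(), (x.cuda(),))
    print(batch_norm_targets(ep_cuda))
```

Because the fusion matcher previously only looked for `_native_batch_norm_legit`, a model exported on CUDA would silently skip conv-bn fusion during QAT preparation; matching both ops makes the behavior device-independent.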