pytorch
22af604e - [quant][pt2] Add Conv + BN fusion for prepare QAT (#98568)

Commit
1 year ago
[quant][pt2] Add Conv + BN fusion for prepare QAT (#98568) **Summary:** This commit adds the `prepare_qat_pt2e` API and the fusion logic for Conv + BN. We use the subgraph rewriter to match and replace the pattern with the existing logic in `nniqat.ConvBn2d`. Note this is not the end-to-end flow yet. In particular, the convert flow needs to swap the new subgraph with another one that merges the batchnorm stats back into conv. The Conv + BN fusion is implemented in the following steps: 1. Annotate all nodes in the pattern `[conv - bn - getitem]` 2. Match and replace this pattern with the fused QAT pattern (note that this is a larger subgraph than the original one) 3. Copy over metadata from the original nodes to the corresponding nodes in the new subgraph, to ensure the stack traces and dtype annotations are preserved 4. Prepare will insert fake quantizes in the right places based on the annotations **Test Plan:** python test/test_quantization.py TestQuantizePT2E.test_qat_conv_bn_fusion **Reviewers:** jerryzh168, kimishpatel, yanboliang Pull Request resolved: https://github.com/pytorch/pytorch/pull/98568 Approved by: https://github.com/kimishpatel
Author
Committer
Parents
Loading