pytorch
ad87365e - [qat]A more stable conv_bn fusion for qat training. (#85744)

Commit
2 years ago
[qat]A more stable conv_bn fusion for qat training. (#85744) Summary: A more stable conv_bn fusion for qat training: * Existing implementation may cause QAT training loss become NaN. This could happen when the fused conv for qat (torch/nn/intrinsic/qat/modules/conv_fused.py) is used and is independent of if fake_quant is enabled. * This is caused by the unscaling for the conv output (`conv_orig = conv / scale_factor` where `scale_factor = bn.weight / running_std`) when there is 0 in `bn.weight`. * This implementation follows the [white paper](https://arxiv.org/pdf/1806.08342.pdf) better and fixed the issue by scaling `running_std / std_Y` instead and compute the fused output accordingly (see comments in conv_fused.py for more details): * It comes at the cost of running conv twice (one to update bn statistics and one to compute fake quant for fused weights). * It does not need to use conv bias for back prop. * It uses the bn statistics computed with the current input batch, while the existing code uses the statistics without the current batch. * The implementation could be enabled by setting the flag `_enable_slow_path_for_better_numerical_stability` to True after the model is prepared for QAT. * Unit test * Added test case for zero `bn.weight`. * Added test case for conv to has bias. Test Plan: buck run mode/dev-nosan //caffe2/test:quantization -- -r quantization.eager.test_quantize_eager_qat Differential Revision: D29506778 Pull Request resolved: https://github.com/pytorch/pytorch/pull/85744 Approved by: https://github.com/vkuzo
Author
Peizhao Zhang
Committer
Parents
Loading