[qat] A more stable conv_bn fusion for QAT training (#85744)
Summary:
A more stable conv_bn fusion for QAT training:
* The existing implementation can cause the QAT training loss to become NaN. This happens when the fused conv module for QAT (torch/nn/intrinsic/qat/modules/conv_fused.py) is used, and is independent of whether fake_quant is enabled.
* The NaNs come from un-scaling the conv output (`conv_orig = conv / scale_factor`, where `scale_factor = bn.weight / running_std`) when `bn.weight` contains zeros.
* The new implementation follows the [white paper](https://arxiv.org/pdf/1806.08342.pdf) more closely and fixes the issue by scaling with `running_std / std_Y` instead and computing the fused output accordingly (see the comments in conv_fused.py for more details, and the sketch after the list below):
  * It comes at the cost of running conv twice (once to update the bn statistics and once to compute the output with the fake-quantized fused weights).
  * It does not need to use the conv bias for backprop.
  * It uses the bn statistics computed from the current input batch, while the existing code uses statistics that do not include the current batch.
* The new path can be enabled by setting the flag `_enable_slow_path_for_better_numerical_stability` to True after the model is prepared for QAT (see the usage sketch after the list below).
* Unit tests
  * Added a test case for zero `bn.weight`.
  * Added a test case for conv with bias.
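
A minimal illustrative sketch (not the actual conv_fused.py code) contrasting the two scalings described above: the old path divides the conv output by `bn.weight / running_std` and produces NaN when `bn.weight` has zeros, while the paper-style path rescales by `running_std / batch_std` and never divides by `bn.weight`. Fake-quant, conv bias, and running-stat updates are omitted; the function names and toy shapes are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

def fast_path(x, w, gamma, beta, running_mean, running_var, eps=1e-5):
    # Old approach: fold gamma / running_std into the weight, run conv once,
    # then un-scale so BN can see the original conv output. The un-scaling
    # divides by gamma, so any gamma == 0 yields 0 / 0 = NaN.
    running_std = torch.sqrt(running_var + eps)
    scale_factor = gamma / running_std                     # zeros in gamma -> zeros here
    scaled_w = w * scale_factor.reshape(-1, 1, 1, 1)
    conv = F.conv2d(x, scaled_w)
    conv_orig = conv / scale_factor.reshape(1, -1, 1, 1)   # division by zero -> NaN
    return conv_orig

def slow_path(x, w, gamma, beta, running_mean, running_var, eps=1e-5):
    # Paper-style approach: run conv twice. The first pass gives the batch
    # statistics; the second pass uses the fused weight (the one that would be
    # fake-quantized), and its output is rescaled by running_std / batch_std,
    # which never divides by gamma.
    conv = F.conv2d(x, w)
    dims = (0, 2, 3)
    batch_mean = conv.mean(dims)
    batch_std = torch.sqrt(conv.var(dims, unbiased=False) + eps)

    running_std = torch.sqrt(running_var + eps)
    fused_w = w * (gamma / running_std).reshape(-1, 1, 1, 1)
    conv_fused = F.conv2d(x, fused_w)

    # conv * (gamma / running_std) * (running_std / batch_std) == conv * gamma / batch_std
    conv_fused = conv_fused * (running_std / batch_std).reshape(1, -1, 1, 1)
    bias = beta - gamma * batch_mean / batch_std
    return conv_fused + bias.reshape(1, -1, 1, 1)

if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 3, 8, 8)
    w = torch.randn(6, 3, 3, 3)
    gamma = torch.tensor([1.0, 0.0, 0.5, 2.0, 0.0, 1.5])  # zeros in bn.weight
    beta = torch.zeros(6)
    running_mean = torch.zeros(6)
    running_var = torch.ones(6)
    print(torch.isnan(fast_path(x, w, gamma, beta, running_mean, running_var)).any())  # tensor(True)
    print(torch.isnan(slow_path(x, w, gamma, beta, running_mean, running_var)).any())  # tensor(False)
```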
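
A hedged usage sketch for enabling the flag after eager-mode QAT preparation; the toy model is a placeholder, and everything other than the flag name itself is standard eager-mode QAT API, assuming a PyTorch build that provides `torch.ao.quantization.fuse_modules_qat`:

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Toy conv-bn model; prepare_qat swaps the fused pair for a QAT ConvBn2d.
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

model = M().train()
model.qconfig = tq.get_default_qat_qconfig("fbgemm")
tq.fuse_modules_qat(model, [["conv", "bn"]], inplace=True)
tq.prepare_qat(model, inplace=True)

# After prepare_qat, opt each fused conv-bn module into the slower but more
# numerically stable forward path.
for m in model.modules():
    if hasattr(m, "_enable_slow_path_for_better_numerical_stability"):
        m._enable_slow_path_for_better_numerical_stability = True
```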
Test Plan: buck run mode/dev-nosan //caffe2/test:quantization -- -r quantization.eager.test_quantize_eager_qat
Differential Revision: D29506778
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85744
Approved by: https://github.com/vkuzo