Quantization aware training in eager mode (#22732)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22732
Add support for quantization aware training in eager mode
Modifications to Post training flow:
## Prepare
* Fusion: e.g. (Conv, Bn) → ConvBn (float)
* Swapping: To insert fake_quant to weight, we need to swap the float modules that has weight with different qat modules, e.g. Conv → torch.nn.qat.Conv , ConvBn → torch.nn._intrinsic.qat.ConvBn
```
* previously we were thinking about modify the weight in forward_pre hook and change it back in forward_hook:
* def forward_pre_hook(self, input):
self.float_weight = self.weight
self.weight = self.fake_quantize(self.float_weight)
def forward_hook(self, input):
self.weight = self.float_weight
```
* Assignments to self.weight are needed because we can’t change forward function and in forward function they are using self.weight.
* But we will need to keep two copies of weight in this case, so it’s probably better to just swap the module
* So we want to just swap Conv to torch.nn.qat.Conv and Linear to torch.nn.qat.Linear
* qat modules will have fake_quant for output and weights inserted in forward function
## Convert
* flow should be identical to ptq, but the swapping dictionary is slightly different since modules are changed in prepare step.
Reviewed By: zafartahirov
Differential Revision: D16199356
fbshipit-source-id: 62aeaf47c12c62a87d9cac208f25f7592e245d6c