Generalized Learnable Fake Quantizer Module (#41535)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41535
A generalized fake quantization module is built to support lower-bit fake quantization with backpropagation on the scale and zero point. The module supports both per tensor and per channel fake quantization.
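For reference, the forward pass computes the standard affine fake quantization (quantize, then dequantize); a minimal per tensor sketch, assuming the usual affine formulation (names here are illustrative, not the actual operator):

```python
import torch

def fake_quantize_per_tensor(x, scale, zero_point, quant_min, quant_max):
    # Quantize: project onto the integer grid and clamp to the bit range.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize: map back to floating point, preserving the quantization error.
    return (q - zero_point) * scale
```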
Test Plan:
Please see diff D22337313 for a related experiment performed on the fake quantizer module.
The `_LearnableFakeQuantize` module supports the following use cases:
- Per Tensor Fake Quantization or Per Channel Fake Quantization
- Static Estimation from Observers or Quantization Parameter Learning through Backpropagation
By default, the module assumes per tensor affine fake quantization. To switch to per channel fake quantization, declare `channel_size` with the appropriate length at initialization. To toggle between static estimation and parameter learning with backpropagation, invoke `enable_param_learning` or `enable_static_estimate`. For more information on the flags that support these operations, please see the docstring of the `_LearnableFakeQuantize` module.
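As a rough usage sketch (the import path and the `observer`/`quant_min`/`quant_max` constructor arguments are assumptions modeled on the existing `FakeQuantize` module, not confirmed by this diff):

```python
import torch
from torch.quantization import MovingAverageMinMaxObserver
# Assumed import path for the new module.
from torch.quantization._learnable_fake_quantize import _LearnableFakeQuantize

# Per tensor affine fake quantization (the default); for per channel,
# pass `channel_size` with the appropriate length at initialization.
fq = _LearnableFakeQuantize(
    observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255)

fq.enable_param_learning()    # learn scale/zero point via backpropagation
# fq.enable_static_estimate() # or estimate them statically from the observer

x = torch.randn(4, 8)
y = fq(x)  # fake-quantized output; gradients flow to scale and zero point
```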
The `_LearnableFakeQuantize` module relies on 2 operators for its forward and backward paths: `_LearnableFakeQuantizePerTensorOp` and `_LearnableFakeQuantizePerChannelOp`. The backpropagation routine is based on the following literature; a sketch of the resulting scale gradient follows the references:
- Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
- Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
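For intuition, here is a minimal LSQ-style per tensor backward pass as a custom autograd function. The class name is hypothetical, and this sketch omits zero point learning and LSQ's gradient scaling; it is not the actual operator implementation:

```python
import torch

class _LSQPerTensorSketch(torch.autograd.Function):
    """Illustrative LSQ-style per tensor fake quantize with a learnable scale."""

    @staticmethod
    def forward(ctx, x, scale, quant_min, quant_max):
        ctx.save_for_backward(x, scale)
        ctx.quant_min, ctx.quant_max = quant_min, quant_max
        q = torch.clamp(torch.round(x / scale), quant_min, quant_max)
        return q * scale  # fake-quantized (quantize-dequantize) output

    @staticmethod
    def backward(ctx, grad_out):
        x, scale = ctx.saved_tensors
        q_min, q_max = ctx.quant_min, ctx.quant_max
        v = x / scale
        below, above = v < q_min, v > q_max
        middle = ~(below | above)
        # Straight-through estimator: pass input gradients only in range.
        grad_x = grad_out * middle.to(grad_out.dtype)
        # LSQ scale gradient: round(v) - v inside the range,
        # q_min / q_max where the input is clipped.
        grad_s = torch.where(
            middle, torch.round(v) - v,
            torch.where(below, torch.full_like(v, q_min),
                        torch.full_like(v, q_max)))
        grad_scale = (grad_out * grad_s).sum().reshape(scale.shape)
        return grad_x, grad_scale, None, None

# Example: gradients reach the learnable scale.
scale = torch.tensor([0.1], requires_grad=True)
x = torch.randn(8)
y = _LSQPerTensorSketch.apply(x, scale, -128, 127)
y.sum().backward()  # populates scale.grad
```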
Reviewed By: z-a-f
Differential Revision: D22573645
fbshipit-source-id: cfd9ece8a959ae31c00d9beb1acf9dfed71a7ea1