Implement QAT for APoT (#83282)
### Summary:
This PR implements QAT for APoT FakeQuant. It runs QAT on FX graph mode quantized models (pre-trained ResNet-18, full ImageNet dataset) to compare accuracy metrics across qconfig settings that mix uniform and APoT quantization for activations and weights. It also refactors the APoT PTQ module `apot_fx_graph_mode_ptq.py` (previously `fx_graph_mode_apot.py`) so that helper functions shared between PTQ and QAT live in a separate file, `quantization_util.py`.
Model #2 (uniformly quantized activation, APoT quantized weight) shows accuracy comparable to model #1 (uniformly quantized activation, uniformly quantized weight) at 8 bits, and a significant accuracy improvement at 4 bits (see the "Accuracy Stats" section below).
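Below is a minimal sketch (not the PR's test script) of how a QAT qconfig pairing uniform activation fake quantization with APoT weight fake quantization might be assembled in FX graph mode. The `APoTFakeQuantize` import path and its `b`/`k` arguments are assumptions based on the experimental namespace, and the `prepare_qat_fx` signature varies across PyTorch versions.

```python
import torch
from torchvision.models import resnet18
from torch.ao.quantization import QConfig, default_fake_quant
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
# Assumed location of the experimental APoT fake quantizer; exact path/args may differ.
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize

# Model #2 style qconfig: uniform (int8) fake-quant for activations,
# APoT (b=8, k=2) fake-quant for weights.
apot_weight_qconfig = QConfig(
    activation=default_fake_quant,
    weight=APoTFakeQuantize.with_args(b=8, k=2),
)

model = resnet18(pretrained=True)
model.train()

example_inputs = (torch.randn(1, 3, 224, 224),)
prepared = prepare_qat_fx(model, {"": apot_weight_qconfig}, example_inputs)

# ... fine-tune `prepared` on ImageNet with fake quantization in the loop ...

prepared.eval()
quantized = convert_fx(prepared)  # final quantized model used for evaluation
```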
### Test Plan:
Run QAT models with: `python test/quantization/core/experimental/apot_qat.py`
Run PTQ models with: `python test/quantization/core/experimental/apot_ptq.py`
### Accuracy Stats
All models are quantized with FX graph mode and evaluated on the ImageNet test dataset.

8-bit (Uniform int8, APoT b = 8 k = 2)

| Model | Activation | Weight | Top-1 | Top-5 |
| --- | --- | --- | --- | --- |
| #1 | Uniform | Uniform | 69.67% | 89.04% |
| #2 | Uniform | APoT | 69.72% | 89.06% |

4-bit (Uniform int4, APoT b = 4 k = 2)

| Model | Activation | Weight | Top-1 | Top-5 |
| --- | --- | --- | --- | --- |
| #1 | Uniform | Uniform | 46.85% | 72.85% |
| #2 | Uniform | APoT | 66.45% | 86.23% |
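For the notation above: `b` is the total bitwidth and `k` is the base bitwidth, so each APoT quantization level is a sum of `b / k` power-of-two terms. The snippet below is only an illustrative sketch of how such a level grid can be built following the Additive Powers-of-Two formulation; the experimental PyTorch observer may scale and clamp the levels differently.

```python
from itertools import product

def apot_levels(b: int, k: int):
    """Build normalized APoT quantization levels for total bitwidth b, base bitwidth k."""
    n = b // k  # number of additive power-of-two terms
    term_sets = []
    for i in range(n):
        # each term is either 0 or one of (2**k - 1) powers of two
        term_sets.append([0.0] + [2.0 ** -(i + j * n) for j in range(2 ** k - 1)])
    # every level is the sum of one entry from each term set
    levels = sorted({sum(combo) for combo in product(*term_sets)})
    return [lvl / levels[-1] for lvl in levels]  # normalize so the largest level is 1

print(len(apot_levels(4, 2)))  # 16 non-uniformly spaced levels for b = 4, k = 2
```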
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83282
Approved by: https://github.com/jerryzh168