[quant] Create FusedMovingAvgObsFakeQuantize for QAT (#61691)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61691
Create a new module for QAT that does a Fused MovingAvgMinMaxObserver and FakeQuantize operation
The module currently only supports per-tensor quantization (affine/symmetric). Follow-up PR will add support for per-channel
Results on running QAT with MobileNetV2 (Obs enabled/fake_quant enabled)
Original FQ module
PyTorchObserver {"type": "_", "metric": "qnnpack_fp_latency_ms", "unit": "ms", "value": "242.80261993408203"}
PyTorchObserver {"type": "_", "metric": "qnnpack_qat0_latency_ms", "unit": "ms", "value": "505.7964324951172"}
PyTorchObserver {"type": "_", "metric": "fbgemm_fp_latency_ms", "unit": "ms", "value": "235.80145835876465"}
PyTorchObserver {"type": "_", "metric": "fbgemm_qat0_latency_ms", "unit": "ms", "value": "543.8144207000732"}
Fused FakeQuant module (~50% improvement in latency)
PyTorchObserver {"type": "_", "metric": "qnnpack_fp_latency_ms", "unit": "ms", "value": "232.1624755859375"}
PyTorchObserver {"type": "_", "metric": "qnnpack_qat0_latency_ms", "unit": "ms", "value": "263.8866901397705"}
PyTorchObserver {"type": "_", "metric": "fbgemm_fp_latency_ms", "unit": "ms", "value": "236.9832992553711"}
PyTorchObserver {"type": "_", "metric": "fbgemm_qat0_latency_ms", "unit": "ms", "value": "292.1590805053711"}
Individual module benchmark result (>5x improvement in latency)
===> Baseline FakeQuantize module
```
--------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
--------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::fake_quantize_per_tensor_affine 0.77% 1.210ms 4.92% 7.730ms 154.596us 718.528us 0.45% 9.543ms 190.862us 50
aten::fake_quantize_per_tensor_affine_cachemask 2.41% 3.792ms 4.15% 6.520ms 130.402us 8.825ms 5.58% 8.825ms 176.492us 50
aten::_aminmax 3.25% 5.105ms 4.43% 6.955ms 139.102us 8.193ms 5.18% 8.193ms 163.868us 50
aten::zeros_like 1.87% 2.939ms 6.95% 10.922ms 109.218us 5.992ms 3.79% 10.844ms 108.442us 100
aten::zeros 0.97% 1.527ms 3.11% 4.885ms 97.702us 2.383ms 1.51% 4.800ms 96.010us 50
aten::rsub 1.34% 2.106ms 2.94% 4.614ms 92.277us 2.063ms 1.30% 4.559ms 91.173us 50
aten::clamp 2.79% 4.381ms 5.42% 8.519ms 85.190us 5.385ms 3.41% 8.438ms 84.381us 100
aten::eq 11.70% 18.384ms 21.31% 33.479ms 83.280us 22.465ms 14.21% 33.310ms 82.861us 402
aten::ones 1.05% 1.656ms 2.57% 4.038ms 80.751us 2.494ms 1.58% 3.951ms 79.028us 50
aten::le 2.52% 3.955ms 4.84% 7.607ms 76.071us 4.998ms 3.16% 7.702ms 77.016us 100
aten::min 0.69% 1.087ms 2.32% 3.641ms 72.827us 1.017ms 0.64% 3.603ms 72.055us 50
aten::max 1.40% 2.195ms 4.62% 7.260ms 72.597us 2.008ms 1.27% 7.140ms 71.404us 100
aten::is_nonzero 2.68% 4.207ms 11.35% 17.829ms 71.033us 4.062ms 2.57% 17.225ms 68.625us 251
aten::detach 1.17% 1.831ms 3.65% 5.736ms 57.360us 1.680ms 1.06% 5.634ms 56.340us 100
aten::mul 3.36% 5.278ms 3.36% 5.278ms 53.862us 5.215ms 3.30% 5.215ms 53.216us 98
aten::div 3.42% 5.376ms 3.42% 5.376ms 53.759us 5.320ms 3.36% 5.320ms 53.196us 100
aten::sub 6.79% 10.672ms 6.79% 10.672ms 53.901us 10.504ms 6.64% 10.504ms 53.050us 198
aten::item 4.06% 6.380ms 12.02% 18.883ms 53.798us 6.127ms 3.87% 18.322ms 52.198us 351
aten::add 3.28% 5.147ms 3.28% 5.147ms 52.518us 5.113ms 3.23% 5.113ms 52.171us 98
aten::minimum 1.63% 2.555ms 1.63% 2.555ms 51.092us 2.585ms 1.64% 2.585ms 51.708us 50
aten::maximum 3.22% 5.065ms 3.22% 5.065ms 50.646us 5.133ms 3.25% 5.133ms 51.329us 100
aten::round 1.61% 2.529ms 1.61% 2.529ms 50.578us 2.528ms 1.60% 2.528ms 50.552us 50
aten::zero_ 1.99% 3.125ms 4.72% 7.422ms 49.481us 2.835ms 1.79% 7.269ms 48.462us 150
aten::copy_ 6.62% 10.394ms 6.62% 10.394ms 41.576us 10.252ms 6.48% 10.252ms 41.010us 250
detach 2.49% 3.905ms 2.49% 3.905ms 39.049us 3.954ms 2.50% 3.954ms 39.539us 100
aten::select 2.01% 3.154ms 2.47% 3.876ms 38.759us 3.866ms 2.44% 3.866ms 38.658us 100
aten::_local_scalar_dense 7.96% 12.503ms 7.96% 12.503ms 35.621us 12.195ms 7.71% 12.195ms 34.743us 351
aten::to 2.31% 3.625ms 4.16% 6.530ms 32.650us 4.320ms 2.73% 6.270ms 31.348us 200
aten::fill_ 3.70% 5.808ms 3.70% 5.808ms 29.039us 5.892ms 3.73% 5.892ms 29.459us 200
aten::as_strided 0.79% 1.244ms 0.79% 1.244ms 6.221us 0.000us 0.00% 0.000us 0.000us 200
aten::empty 3.55% 5.579ms 3.55% 5.579ms 11.137us 0.000us 0.00% 0.000us 0.000us 501
aten::resize_ 2.36% 3.712ms 2.36% 3.712ms 12.332us 0.000us 0.00% 0.000us 0.000us 301
aten::empty_like 1.45% 2.284ms 3.68% 5.776ms 28.878us 0.000us 0.00% 0.000us 0.000us 200
aten::empty_strided 2.80% 4.398ms 2.80% 4.398ms 17.592us 0.000us 0.00% 0.000us 0.000us 250
--------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 157.108ms
Self CUDA time total: 158.122ms
```
===> FusedFakeQuant
```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
fb::fused_fake_quant 23.42% 6.408ms 100.00% 27.361ms 547.215us 7.887ms 27.20% 28.996ms 579.925us 50
aten::fake_quantize_per_tensor_affine 4.25% 1.162ms 27.65% 7.565ms 151.298us 686.176us 2.37% 10.217ms 204.336us 50
aten::_fake_quantize_per_tensor_affine_cachemask_ten... 14.11% 3.860ms 23.40% 6.403ms 128.068us 9.531ms 32.87% 9.531ms 190.612us 50
aten::_aminmax 20.57% 5.628ms 27.47% 7.515ms 150.305us 8.218ms 28.34% 8.218ms 164.367us 50
aten::item 3.65% 999.522us 10.27% 2.810ms 56.202us 931.904us 3.21% 2.674ms 53.481us 50
aten::_local_scalar_dense 6.62% 1.811ms 6.62% 1.811ms 36.212us 1.742ms 6.01% 1.742ms 34.843us 50
aten::empty 10.85% 2.969ms 10.85% 2.969ms 14.843us 0.000us 0.00% 0.000us 0.000us 200
aten::as_strided 1.92% 524.365us 1.92% 524.365us 5.244us 0.000us 0.00% 0.000us 0.000us 100
aten::empty_like 6.48% 1.774ms 14.62% 4.000ms 26.670us 0.000us 0.00% 0.000us 0.000us 150
aten::empty_strided 8.14% 2.226ms 8.14% 2.226ms 14.842us 0.000us 0.00% 0.000us 0.000us 150
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 27.361ms
Self CUDA time total: 28.996ms
```
Test Plan:
python test/test_quantization.py TestFusedObsFakeQuantModule
Imported from OSS
Reviewed By: vkuzo
Differential Revision: D29706889
fbshipit-source-id: ae3f9fb1fc559920459bf6e8663e8299bf7d21e1