[quant] Speed up per-channel min-max observer (#34118)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34118
Previously, calc_per_channel_qparams used for loops and Python primitives, which called `item` many times and caused a slowdown during training.
This change uses torch primitives on the tensor instead, speeding up the operation by over 60x.
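For illustration only, a minimal sketch of the loop-versus-tensor-op difference described above. This is not the code from this PR: the function names, the affine (asymmetric) quantization scheme, and the default qmin/qmax are assumptions chosen for the example.

```python
import torch


def per_channel_qparams_loop(min_vals, max_vals, qmin=0, qmax=255):
    """Hypothetical slow path: one `item` call per channel per tensor."""
    eps = torch.finfo(torch.float32).eps
    scales, zero_points = [], []
    for i in range(min_vals.numel()):
        # Each `item` call copies one scalar back to Python.
        lo = min(min_vals[i].item(), 0.0)
        hi = max(max_vals[i].item(), 0.0)
        scale = max((hi - lo) / float(qmax - qmin), eps)
        zero_point = qmin - round(lo / scale)
        zero_point = max(qmin, min(qmax, zero_point))
        scales.append(scale)
        zero_points.append(zero_point)
    return torch.tensor(scales), torch.tensor(zero_points, dtype=torch.int64)


def per_channel_qparams_vectorized(min_vals, max_vals, qmin=0, qmax=255):
    """Hypothetical fast path: the same math as whole-tensor operations."""
    eps = torch.finfo(torch.float32).eps
    # Extend each channel's range so that it always contains zero.
    lo = torch.clamp(min_vals, max=0.0)
    hi = torch.clamp(max_vals, min=0.0)
    scale = torch.clamp((hi - lo) / float(qmax - qmin), min=eps)
    zero_point = qmin - torch.round(lo / scale)
    zero_point = torch.clamp(zero_point, qmin, qmax).to(torch.int64)
    return scale, zero_point


min_vals = torch.tensor([-64.0, 0.5, -32.0])
max_vals = torch.tensor([191.0, 127.5, 95.5])
scale, zero_point = per_channel_qparams_vectorized(min_vals, max_vals)
loop_scale, loop_zero_point = per_channel_qparams_loop(min_vals, max_vals)
```

The vectorized version issues a fixed number of tensor ops regardless of the channel count, so it avoids the per-channel scalar transfers that the loop version incurs.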
Perf results on MobileNetV2 during training, measured with the autograd profiler:
FP32 forward call -
Self CPU time total: 47.222ms
CUDA time total: 124.001ms
Before change, FakeQuant Model -
Self CPU time total: 19.107s
CUDA time total: 27.177s
After change, FakeQuant Model -
Self CPU time total: 404.667ms
CUDA time total: 446.344ms
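A rough sketch of how totals like the ones above can be collected with the autograd profiler. The small model and input below are stand-ins chosen for the example, not the MobileNetV2 setup used for these numbers, and CUDA timing is not enabled in this sketch.

```python
import torch
import torch.nn as nn

# Stand-in model and input (assumption: any nn.Module works the same way).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
x = torch.randn(1, 3, 32, 32)

# Record every operator executed inside the context manager.
with torch.autograd.profiler.profile() as prof:
    model(x)

# Aggregate by operator name and render a summary table as a string.
report = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)
print(report)
```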
Test Plan:
python test/test_quantization.py
Imported from OSS
Differential Revision: D20287841
fbshipit-source-id: 6b706b8206e0d0da3c3c217b014e8da5b71b870d