fake_quant: add a more memory efficient version (#50561)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50561
Not for review yet; a number of TODOs still need finalizing.
tl;dr: add an alternative implementation of `fake_quantize` which saves
a mask during the forward pass and uses it to calculate the backward (see the sketch after the list below).
There are two benefits:
1. the backward function no longer needs the input Tensor, so autograd can
free it earlier. On MobileNetV2, this reduces QAT overhead
by ~15% (TODO: link, and absolute numbers). We do add a mask Tensor
to pass around, but it is 4x smaller than the input tensor. A
future optimization would be to pack the mask bitwise and unpack it in the
backward.
2. the computation of `qval` can be done only once in the forward and
reused in the backward. No perf change observed so far; TODO: verify with better
metrics.
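As a rough illustration of the idea (not the actual kernels added in this PR), a mask-saving fake quantize can be written as a custom autograd Function. The name `_FakeQuantizeWithMask` and the per-tensor affine parameters below are hypothetical:
```
import torch

class _FakeQuantizeWithMask(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, zero_point, quant_min, quant_max):
        # Quantize once in the forward.
        q = torch.round(x / scale) + zero_point
        # Mask of elements inside the quantizable range; the straight-through
        # estimator only passes gradients for these elements.
        mask = (q >= quant_min) & (q <= quant_max)
        y = (torch.clamp(q, quant_min, quant_max) - zero_point) * scale
        # Save only the bool mask (1 byte/element) instead of the fp32 input
        # (4 bytes/element), so autograd can free the input earlier.
        ctx.save_for_backward(mask)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Straight-through estimator: gradient flows only where the value
        # was not clamped.
        return grad_output * mask, None, None, None, None
```
Usage would look like `y = _FakeQuantizeWithMask.apply(x, scale, zero_point, quant_min, quant_max)`; only `mask` is kept alive for the backward, so `x` can be released as soon as nothing else in the forward references it.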
TODO: describe in more detail
Test Plan:
OSS / torchvision / MobileNetV2
```
python references/classification/train_quantization.py \
  --print-freq 1 \
  --data-path /data/local/packages/ai-group.imagenet-256-smallest-side/prod/ \
  --output-dir ~/nfs/pytorch_vision_tests/ \
  --backend qnnpack \
  --epochs 5

TODO paste results here
```
TODO more
Imported from OSS
Reviewed By: ngimel
Differential Revision: D25918519
fbshipit-source-id: ec544ca063f984de0f765bf833f205c99d6c18b6