fake_quant: more memory efficient per-channel backward (#51255)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51255
This is the same as #50561, but for per-channel fake_quant.
TODO before land write up better
Memory and performance impact (MobileNetV2): TODO
Performance impact (microbenchmarks): https://gist.github.com/vkuzo/fbe1968d2bbb79b3f6dd776309fbcffc
* forward pass on cpu: 512ms -> 750ms (+46%)
* forward pass on cuda: 99ms -> 128ms (+30%)
* note: the overall performance impact to training jobs should be minimal, because this is used for weights, and relative importance of fq is dominated by fq'ing the activations
* note: we can optimize the perf in a future PR by reading once and writing twice
Test Plan:
```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D26117721
fbshipit-source-id: 798b59316dff8188a1d0948e69adf9e5509e414c