Migrate fake_quant_slice to TensorIterator (#33744)
Summary:
This is a quick improvement for per tensor quantization.
per-channel should remove the loop in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp
# Benchmark:
device = GTX-1650
```python
import torch
print(torch.__version__)
for i in range(1000):
torch.randn(1024 * 128, device='cuda')
def f(e):
a = torch.randn(2 ** e, device='cuda')
torch.cuda.synchronize()
%timeit torch.fake_quantize_per_tensor_affine(a, 0.5, 0, 0, 1); torch.cuda.synchronize()
for i in range(15, 27):
f(i)
```
Before
```
1.5.0a0+bf00b4d
14.5 µs ± 981 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
18.2 µs ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
25.6 µs ± 2.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
38.6 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
70.2 µs ± 5.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
125 µs ± 4.98 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
231 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
461 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
891 µs ± 88.2 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.77 ms ± 8.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.77 ms ± 80.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.16 ms ± 216 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
After
```
1.5.0a0+3f18ac3
12.5 µs ± 738 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
13.7 µs ± 195 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
17.9 µs ± 850 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
29.7 µs ± 285 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
50.4 µs ± 1.94 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
95 µs ± 8.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
173 µs ± 7.37 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
348 µs ± 29.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
657 µs ± 22.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.33 ms ± 77.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.71 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.33 ms ± 439 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33744
Differential Revision: D20090129
Pulled By: ngimel
fbshipit-source-id: 5dd48a0c5455a2b6c5c638d747c1767cb259255d