quantized torch.topk (#26486)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26486
This PR adds a quantized version of torch.topk, supporting the same options as the floating-point version.
Benchmark script
```
import torch
import time

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    X = torch.rand(6, 5, 1024)
    qX = torch.quantize_linear(X, 0.01, 0, dtype)
    X = qX.dequantize()

    NITER = 10000

    s = time.time()
    for i in range(NITER):
        float_out = torch.topk(X, 50)
    float_time_per_iter = (time.time() - s) / NITER

    s = time.time()
    for i in range(NITER):
        quant_out = torch.topk(qX, 50)
    quant_time_per_iter = (time.time() - s) / NITER

    print(dtype)
    print('float ms', 'quant ms', 'float gB/s', 'quant gB/s', sep='\t')
    nbytes_float = (X.numel() + float_out[0].numel()) * X.element_size()
    nbytes_quant = (qX.numel() + quant_out[0].numel()) * qX.element_size()
    print(float_time_per_iter * 1000,
          quant_time_per_iter * 1000,
          nbytes_float / float_time_per_iter / 1e9,
          nbytes_quant / quant_time_per_iter / 1e9, sep='\t')
```
Results
```
torch.qint8
float ms quant ms float gB/s quant gB/s
0.3706729888916016 0.3370296716690064 0.34769191136743244 0.09559989136992947
torch.quint8
float ms quant ms float gB/s quant gB/s
0.38260042667388916 0.34079675674438475 0.3368527346412275 0.09454315325003715
torch.qint32
float ms quant ms float gB/s quant gB/s
0.38033516407012935 0.3364055633544922 0.3388590174539739 0.38310900305828427
```
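The reason topk can operate directly on the integer representation is that affine quantization with a positive scale is order-preserving, so the top-k indices computed over the quantized values match those computed over the dequantized floats. A minimal pure-Python sketch of that property (the `quantize` and `topk` helpers and their parameters are illustrative, not the PR's implementation):

```
import heapq

def quantize(x, scale=0.01, zero_point=0, qmin=-128, qmax=127):
    # Affine quantization: q = clamp(round(x / scale) + zero_point).
    # With scale > 0 this map is monotonically non-decreasing in x.
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))

def topk(values, k):
    # Return (top-k values, their indices), largest first.
    idx = heapq.nlargest(k, range(len(values)), key=lambda i: values[i])
    return [values[i] for i in idx], idx

xs = [0.3, -0.7, 0.9, 0.1, 0.5]
qs = [quantize(x) for x in xs]

# Because the quantization map is monotonic, the ordering of qs matches
# the ordering of xs, so topk over the int representation yields the
# same indices as topk over the floats.
_, float_idx = topk(xs, 3)
_, quant_idx = topk(qs, 3)
assert float_idx == quant_idx
```

This is why the quantized kernel never needs to dequantize: it can select on the raw integers and return them with the original scale and zero point attached.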
Test Plan: Imported from OSS
Differential Revision: D17529988
Pulled By: jamesr66a
fbshipit-source-id: b5edfe90c592b6c84459d1c0c77e4c18f5b04417