Add the quantized average_pool2d support and adaptive_avg_pool2d support (#25899)
Summary:
Benchmark script copied from PR https://github.com/pytorch/pytorch/issues/25676
===============For avg_pool2d==============
import torch, time

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_linear(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])
    NITER = 100

    s = time.time()
    for i in range(NITER):
        float_out = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=None, padding=0)
    time_per_iter_float = (time.time() - s) / NITER

    s = time.time()
    for i in range(NITER):
        quant_out = torch.nn.quantized.functional.avg_pool2d(q_x, kernel_size=3, stride=None, padding=0)
    time_per_iter_quant = (time.time() - s) / NITER

    ref_quantized = torch.quantize_linear(float_out, 0.5, 1, dtype)
    torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t')

    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()
    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9
    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')
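
The assert_allclose check above works because averaging commutes with the affine quantization map: if q = round(x / scale) + zero_point, then dequantization is x ≈ scale * (q - zero_point), so mean(dequantize(q)) == scale * (mean(q) - zero_point). This is what allows a quantized pooling kernel to average the raw integers directly and reuse the input's scale and zero point. A minimal sketch of the identity (illustration only, not the kernel code added by this PR; it uses the same torch.quantize_linear API as the benchmark, which later releases renamed to torch.quantize_per_tensor):

import torch

scale, zero_point = 0.5, 1
x = torch.rand(9)
q = torch.quantize_linear(x, scale, zero_point, torch.quint8)

# Average in the integer domain, then map back through the affine params.
int_avg = q.int_repr().float().mean()
# Average after dequantizing each element.
deq_avg = q.dequantize().mean()

# The two agree up to float rounding: mean(scale*(q - zp)) == scale*(mean(q) - zp).
assert abs(scale * (int_avg - zero_point) - deq_avg) < 1e-5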
Before the vectorization:
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.67439603805542 7.126874923706055 2.6648539791017924
GB/s float GB/s quant
1.2470733401269298 0.11699265230915809
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.587001323699951 7.011299133300781 2.7102031487456535
GB/s float GB/s quant
1.2892022781148076 0.11892118481150399
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.6659250259399414 7.03080415725708 2.637285028215745
GB/s float GB/s quant
1.2510359321992184 0.4743650833393638
After the vectorization:
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.6113319396972656 0.5631613731384277 0.2156605847679846
GB/s float GB/s quant
1.2771903676047593 1.48055608884072
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.5221967697143555 0.5518221855163574 0.21878633425529784
GB/s float GB/s quant
1.322326647963202 1.5109794819499591
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.5173258781433105 4.0132904052734375 1.5942673295177407
GB/s float GB/s quant
1.324885279636461 0.8310308159154421
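
A note on why vectorization pays off here: the benchmark builds x as NHWC (1, 56, 56, 256) and then permutes it to an NCHW view. permute only rewrites strides, so channels remain contiguous in memory and the pooling kernel can sweep whole channel vectors at once. A quick hypothetical check of the layout (not part of the PR):

import torch

x = torch.rand(1, 56, 56, 256)   # NHWC in memory
y = x.permute([0, 3, 1, 2])      # NCHW view over the same storage
print(y.stride())                # (802816, 1, 14336, 256)
# Channel stride is 1: adjacent channels sit next to each other in memory,
# so the kernel can process whole channel vectors with SIMD loads.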
===============For adaptive_avg_pool2d==============
import torch, time

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_linear(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])
    NITER = 100

    s = time.time()
    for i in range(NITER):
        float_out = torch.nn.functional.adaptive_avg_pool2d(x, output_size=5)
    time_per_iter_float = (time.time() - s) / NITER

    s = time.time()
    for i in range(NITER):
        quant_out = torch.nn.quantized.functional.adaptive_avg_pool2d(q_x, output_size=5)
    time_per_iter_quant = (time.time() - s) / NITER

    ref_quantized = torch.quantize_linear(float_out, 0.5, 1, dtype)
    torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())

    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t')

    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()
    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9
    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')
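
Unlike the fixed 3x3 window above, adaptive_avg_pool2d with output_size=5 computes a variable window per output index. The sketch below shows the standard adaptive-pooling index arithmetic, to my understanding the same floor/ceil rule ATen uses; adaptive_windows is a hypothetical helper for illustration, not code from this PR:

import math

def adaptive_windows(in_size, out_size):
    # Output index i averages input rows [floor(i*in/out), ceil((i+1)*in/out)).
    return [(i * in_size // out_size,
             math.ceil((i + 1) * in_size / out_size))
            for i in range(out_size)]

print(adaptive_windows(56, 5))
# [(0, 12), (11, 23), (22, 34), (33, 45), (44, 56)] -- five overlapping 12-row windows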
Before the vectorization:
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.286238670349121 4.600362777709961 2.0121970804594342
GB/s float GB/s quant
1.4158031888707898 0.17590264922602994
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.2867274284362793 4.474163055419922 1.9565790831832832
GB/s float GB/s quant
1.4155005794518536 0.180864217503144
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.3176145553588867 4.264359474182129 1.8399778618588218
GB/s float GB/s quant
1.3966360335956578 0.7590504551966285
After the vectorization:
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.3224568367004395 0.23195743560791016 0.09987588657942796
GB/s float GB/s quant
1.3937240722194333 3.4886400510473843
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.255082130432129 0.2124309539794922 0.09420098324258604
GB/s float GB/s quant
1.435364129899667 3.8093130254365883
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.266514301300049 1.6029787063598633 0.7072440290539581
GB/s float GB/s quant
1.4281242338260862 2.0192807222938463
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25899
Differential Revision: D17437015
Pulled By: llyfacebook
fbshipit-source-id: 496aed1e41711048d0853254d6819d3fb141a0c0