pytorch
b0479506 - Add the 3d avg pool for video related model (#33339)

Commit
6 years ago
Add the 3d avg pool for video related model (#33339) Summary: ``` import torch, time for dtype in [torch.qint8, torch.quint8, torch.qint32]: print('****', str(dtype), '*****') x = torch.rand(1, 5, 56, 56, 256) q_x = torch.quantize_per_tensor(x, 0.5, 1, dtype) q_x = q_x.permute([0, 4, 1, 2, 3]) x = x.permute([0, 4, 1, 2, 3]) NITER = 10 s = time.time() for i in range(NITER): float_out = torch.nn.functional.avg_pool3d(x, kernel_size=3, stride=None, padding=0) time_per_iter_float = (time.time() - s) / NITER s = time.time() for i in range(NITER): quant_out = torch.nn.quantized.functional.avg_pool3d(q_x, kernel_size=3, stride=None, padding=0) time_per_iter_quant = (time.time() - s) / NITER print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t') print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t') ``` ``` **** torch.qint8 ***** time/iter ms (float) time/iter ms (quant) quant/float 16.286182403564453 0.7308721542358398 0.04487682479080417 **** torch.quint8 ***** time/iter ms (float) time/iter ms (quant) quant/float 15.364313125610352 0.6497383117675781 0.042288796541418254 **** torch.qint32 ***** time/iter ms (float) time/iter ms (quant) quant/float 15.649032592773438 13.879132270812988 0.8869003363966556 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/33339 Differential Revision: D19900904 Pulled By: lly-zero-one fbshipit-source-id: 4522cc6b4a0751aeda6c7edc258e0cb3f55a8fe3
Author
Parents
Loading