4a27d2be - Enabling intra-op parallelism for fbgemm_linear_int8_weight_fp32_activation op (#29532)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29532

While we migrate from `torch.jit.quantized` to the `torch.quantization.quantize_dynamic` API, we still need to temporarily add intra-op parallelism support to the legacy `fbgemm_linear_int8_weight_fp32_activation` API, both to parallelize the RNN operators and to help with performance debugging of legacy serialized models that still use the old API (a sketch of the new API is included at the end of this note for reference).

Benchmark script (reports effective GFLOP/s for each thread count):

```
from __future__ import absolute_import, division, print_function, unicode_literals

import time

import torch

K, N = 1024, 1024
print("M, nthread=1, nthread=2, nthread=4, nthread=8, nthread=16")
for M in (2, 20, 200, 500, 1024):
    print(M, sep=",", end=", ")
    for num_threads in (1, 2, 4, 8, 16):
        torch.set_num_threads(num_threads)
        x = torch.rand(M, K)
        w = torch.rand(K, N)
        b = torch.rand(N)
        NITER = 20

        # Quantize and prepack the weight once per configuration.
        W_int8, col_offsets, W_scale, W_zp = torch.fbgemm_linear_quantize_weight(w)
        W_prepack = torch.fbgemm_pack_quantized_matrix(
            W_int8, W_int8.size(1), W_int8.size(0)
        )

        s = time.time()
        for _ in range(NITER):
            Y_fp32 = torch.fbgemm_linear_int8_weight(
                x, w, W_prepack, col_offsets, W_scale, W_zp, b
            )
        elapsed_per_iter_dyn_quant = (time.time() - s) / NITER

        # 2*M*N*K flops per GEMM -> effective GFLOP/s.
        print(
            "{:0.2f}".format(2.0 * M * N * K / elapsed_per_iter_dyn_quant / 1e9),
            end=", ",
        )
    print("\n", end="")
```

On an SKL T1 server:

Before the diff:
```
[root@rtptest33418.frc2 ~/jhuang_test]# ./torch_fbgemm_linear_int8_weight_fp32_activation.par
M, nthread=1, nthread=2, nthread=4, nthread=8, nthread=16
2, 41.01, 51.51, 51.63, 51.49, 52.10,
20, 80.94, 81.43, 82.35, 82.27, 82.24,
200, 87.94, 87.61, 88.53, 88.43, 88.52,
500, 88.76, 89.60, 89.80, 89.65, 89.76,
1024, 88.01, 89.58, 90.11, 90.39, 89.96,
```

After the diff:
```
[root@rtptest33418.frc2 ~/jhuang_test]# ./torch_fbgemm_linear_int8_weight_fp32_activation.par
M, nthread=1, nthread=2, nthread=4, nthread=8, nthread=16
2, 45.08, 70.38, 72.22, 61.59, 44.15,
20, 83.09, 137.86, 205.58, 254.19, 201.08,
200, 87.86, 157.85, 287.24, 420.26, 476.16,
500, 88.57, 162.19, 296.52, 500.91, 530.25,
1024, 88.34, 147.47, 296.78, 534.45, 482.10,
```

ghstack-source-id: 93666880

Test Plan: CI

Differential Revision: D18421371

fbshipit-source-id: 22cc1031ec9ee914c1508ba2aa9ed0281dfcd076
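For context, here is a minimal sketch of the `torch.quantization.quantize_dynamic` path that the summary names as the replacement for the legacy API. The toy model and layer sizes below are illustrative only and are not taken from this diff:

```
import torch
import torch.nn as nn

# Toy model with the module types that dynamic quantization targets
# (sizes are arbitrary, chosen only for illustration).
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# Swap nn.Linear modules for dynamically quantized versions:
# int8 weights, with fp32 activations quantized on the fly at runtime.
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.rand(20, 1024)
y = qmodel(x)  # on x86 builds with FBGEMM this dispatches to the quantized int8 linear kernels
```

Unlike the legacy `fbgemm_linear_int8_weight_fp32_activation` path benchmarked above, weight quantization and prepacking happen inside the converted modules, so no explicit `fbgemm_*` calls are needed.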