pytorch
b154a8cf - Integrating the int64_t GEMM in FBGEMM into PyTorch Linear op (#30143)

Commit
6 years ago
Integrating the int64_t GEMM in FBGEMM into PyTorch Linear op (#30143) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30143 We would like to integrate the int64 GEMM in FBGEMM into PyTorch. This brings ~4x speedup for the Linear op with LongTensor. Benchmark code: ``` from __future__ import absolute_import, division, print_function, unicode_literals import time import torch torch.set_num_threads(1) print("M, N, K, GOPS/sec") for M in range(128, 1025, 128): N = M K = M x = torch.LongTensor(M, K) w = torch.LongTensor(K, N) NITER = 20 # Test torch.nn.functional.linear s = time.time() for _ in range(NITER): torch.nn.functional.linear(x, w) # Z = x @ w elapsed_per_iter_linear = (time.time() - s) / NITER print( "{}, {}, {}, {:0.2f}".format(M, N, K, 2.0 * M * N * K / elapsed_per_iter_linear / 1e9) ) ``` Before this PR: ``` M, N, K, GOPS/sec 128, 128, 128, 2.31 256, 256, 256, 2.49 384, 384, 384, 2.54 512, 512, 512, 2.57 640, 640, 640, 2.46 768, 768, 768, 2.59 896, 896, 896, 2.59 1024, 1024, 1024, 2.61 ``` After this PR: ``` (base) [root@rtptest10054.frc2 ~/jhuang_test/int64_gemm]# python torch_linear.py M, N, K, GOPS/sec 128, 128, 128, 5.35 256, 256, 256, 8.34 384, 384, 384, 9.03 512, 512, 512, 9.22 640, 640, 640, 9.55 768, 768, 768, 9.73 896, 896, 896, 9.82 1024, 1024, 1024, 9.63 ``` ghstack-source-id: 94308012 Test Plan: CI Reviewed By: jspark1105 Differential Revision: D18610019 fbshipit-source-id: f830660927b2666db34427d9de51db011f80f766
Author
Parents
Loading