9b168a1f - [TensorExpr] Pick meaningful names for functions in TE codegen. (#47255)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47255

As a result of this change, the generated CUDA code for the following fusion group:

```
graph(%0 : Float(32, 32, 1, 1, strides=[32, 1, 1, 1], requires_grad=0, device=cuda:0),
      %1 : Float(32, 32, strides=[32, 1], requires_grad=0, device=cuda:0),
      %2 : Float(32, 32, 1, strides=[32, 1, 1], requires_grad=0, device=cuda:0)):
  %3 : int = prim::Constant[value=1]()
  %v1.1 : Float(32, 32, 32, strides=[1024, 32, 1], requires_grad=0, device=cuda:0) = aten::add(%1, %2, %3) # test/test_tensorexpr.py:155:0
  %5 : int = prim::Constant[value=1]()
  %6 : Float(32, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=0, device=cuda:0) = aten::add(%v1.1, %0, %5) # test/test_tensorexpr.py:156:0
  return (%6)
```

Would look like the following:

```
extern "C" __global__ void fused_add_add(float* t0, float* t1, float* t2, float* aten_add) {
  {
    float v = __ldg(t1 + 32 * (((512 * blockIdx.x + threadIdx.x) / 32) % 32) + (512 * blockIdx.x + threadIdx.x) % 32);
    float v_1 = __ldg(t2 + ((512 * blockIdx.x + threadIdx.x) / 32) % 32 + 32 * (((512 * blockIdx.x + threadIdx.x) / 1024) % 32));
    float v_2 = __ldg(t0 + ((512 * blockIdx.x + threadIdx.x) / 1024) % 32 + 32 * ((512 * blockIdx.x + threadIdx.x) / 32768));
    aten_add[((((512 * blockIdx.x + threadIdx.x) / 32768) * 32768 + 32 * (((512 * blockIdx.x + threadIdx.x) / 32) % 32)) + 1024 * (((512 * blockIdx.x + threadIdx.x) / 1024) % 32)) + (512 * blockIdx.x + threadIdx.x) % 32] = (v + v_1) + v_2;
  }
}
```

Previously we generated:

```
extern "C" __global__ void func(float* t0, float* t1, float* t2, float* aten_add) {
  {
    float v = __ldg(t1 + 32 * (((512 * blockIdx.x + threadIdx.x) / 32) % 32) + (512 * blockIdx.x + threadIdx.x) % 32);
    float v_1 = __ldg(t2 + ((512 * blockIdx.x + threadIdx.x) / 32) % 32 + 32 * (((512 * blockIdx.x + threadIdx.x) / 1024) % 32));
    float v_2 = __ldg(t0 + ((512 * blockIdx.x + threadIdx.x) / 1024) % 32 + 32 * ((512 * blockIdx.x + threadIdx.x) / 32768));
    aten_add[((((512 * blockIdx.x + threadIdx.x) / 32768) * 32768 + 32 * (((512 * blockIdx.x + threadIdx.x) / 32) % 32)) + 1024 * (((512 * blockIdx.x + threadIdx.x) / 1024) % 32)) + (512 * blockIdx.x + threadIdx.x) % 32] = (v + v_1) + v_2;
  }
}
```

Differential Revision: D24698273

Test Plan: Imported from OSS

Reviewed By: bertmaher

Pulled By: ZolotukhinM

fbshipit-source-id: 6da95c6ac3d5155ebfaaab4f84f55a24deb6d10d
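The kernel name `fused_add_add` reflects the operators in the fusion group (here, two `aten::add` calls), whereas the old name `func` carried no information about what the kernel computes. The sketch below is only an illustration of that naming idea, not the actual TE codegen code; the helper `fusedKernelName` is a hypothetical name introduced for this example.

```
#include <string>
#include <vector>

// Hypothetical sketch: build a name like "fused_add_add" by concatenating the
// base names of the ops in a fusion group, e.g. {"aten::add", "aten::add"}.
std::string fusedKernelName(const std::vector<std::string>& opNames) {
  std::string name = "fused";
  for (const auto& op : opNames) {
    // Drop the namespace prefix ("aten::", "prim::", ...) if present.
    auto pos = op.rfind("::");
    name += "_" + (pos == std::string::npos ? op : op.substr(pos + 2));
  }
  return name;
}

// fusedKernelName({"aten::add", "aten::add"}) -> "fused_add_add",
// matching the kernel name shown in the generated CUDA code above.
```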
Author: Mikhail Zolotukhin