pytorch
e59db086 - inductor: eliminate meaningless copy (#102089)

Commit
1 year ago
inductor: eliminate meaningless copy (#102089) This pr aims to eliminate meaningless load/store pairs in generated code. HF models on CPU are expected to gain 2~4% E2E training performance improvement. Taking the following case as an example, the generated kernel named cpp_fused_permute_1 does nothing but load and store in_out_ptr0. Example code: ``` @torch._dynamo.optimize("inductor") def fn(permute_6, view_10): permute_5 = torch.ops.aten.permute.default(view_10, [0, 2, 1, 3]) clone_2 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) view_11 = torch.ops.aten.view.default(clone_2, [1024, -1, 32]) bmm = torch.ops.aten.bmm.default(view_11, permute_6) permute_339 = torch.ops.aten.permute.default(view_11, [0, 2, 1]) return (bmm, permute_339) permute_6 = rand_strided((1024, 32, 128), (4096, 1, 32), device='cpu', dtype=torch.float32) view_10 = rand_strided((64, 128, 16, 32), (65536, 512, 32, 1), device='cpu', dtype=torch.float32) out = fn(permute_6, view_10) ``` Output code (Before this pr): ``` aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() cpp_fused_clone_0 = async_compile.cpp(''' #include "/tmp/torchinductor_bzheng/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr0) { #pragma omp parallel num_threads(80) { { #pragma omp for for(long i0=static_cast<long>(0L); i0<static_cast<long>(64L); i0+=static_cast<long>(1L)) { #pragma GCC ivdep for(long i1=static_cast<long>(0L); i1<static_cast<long>(16L); i1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long i2=static_cast<long>(0L); i2<static_cast<long>(128L); i2+=static_cast<long>(1L)) { for(long i3=static_cast<long>(0L); i3<static_cast<long>(32L); i3+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i3 + (32L*i1) + (512L*i2) + (65536L*i0))); tmp0.store(out_ptr0 + static_cast<long>(i3 + (32L*i2) + (4096L*i1) + 
(65536L*i0))); } } } } } } } ''') cpp_fused_permute_1 = async_compile.cpp(''' #include "/tmp/torchinductor_bzheng/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h" extern "C" void kernel(float* in_out_ptr0) { #pragma omp parallel num_threads(80) { { #pragma omp for for(long i0=static_cast<long>(0L); i0<static_cast<long>(4194304L); i0+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<long>(i0)); tmp0.store(in_out_ptr0 + static_cast<long>(i0)); } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg0_1, arg1_1 = args args.clear() buf0 = empty_strided((64, 16, 128, 32), (65536, 4096, 32, 1), device='cpu', dtype=torch.float32) cpp_fused_clone_0(c_void_p(arg1_1.data_ptr()), c_void_p(buf0.data_ptr())) del arg1_1 buf1 = empty_strided((1024, 128, 128), (16384, 128, 1), device='cpu', dtype=torch.float32) extern_kernels.bmm(as_strided(buf0, (1024, 128, 32), (4096, 32, 1)), arg0_1, out=buf1) del arg0_1 buf2 = as_strided(buf0, (1024, 32, 128), (4096, 1, 32)); del buf0 # reuse cpp_fused_permute_1(c_void_p(buf2.data_ptr())) return (buf1, buf2, ) ``` Output code (After this pr): ``` aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() cpp_fused_clone_0 = async_compile.cpp(''' #include "/tmp/torchinductor_bzheng/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr0) { #pragma omp parallel num_threads(80) { { #pragma omp for for(long i0=static_cast<long>(0L); i0<static_cast<long>(64L); i0+=static_cast<long>(1L)) { #pragma GCC ivdep for(long i1=static_cast<long>(0L); i1<static_cast<long>(16L); i1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long i2=static_cast<long>(0L); i2<static_cast<long>(128L); i2+=static_cast<long>(1L)) { for(long i3=static_cast<long>(0L); i3<static_cast<long>(32L); i3+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 
+ static_cast<long>(i3 + (32L*i1) + (512L*i2) + (65536L*i0))); tmp0.store(out_ptr0 + static_cast<long>(i3 + (32L*i2) + (4096L*i1) + (65536L*i0))); } } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg0_1, arg1_1 = args args.clear() buf0 = empty_strided((64, 16, 128, 32), (65536, 4096, 32, 1), device='cpu', dtype=torch.float32) cpp_fused_clone_0(c_void_p(arg1_1.data_ptr()), c_void_p(buf0.data_ptr())) del arg1_1 buf1 = empty_strided((1024, 128, 128), (16384, 128, 1), device='cpu', dtype=torch.float32) extern_kernels.bmm(as_strided(buf0, (1024, 128, 32), (4096, 32, 1)), arg0_1, out=buf1) del arg0_1 return (buf1, as_strided(buf0, (1024, 32, 128), (4096, 1, 32)), ) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/102089 Approved by: https://github.com/jgong5, https://github.com/ngimel
Author
Committer
Parents
Loading