inductor: eliminate meaningless copy (#102089)
This PR aims to eliminate meaningless load/store pairs in generated code. HF models on CPU are expected to gain a 2–4% end-to-end (E2E) training performance improvement.
Taking the following case as an example, the generated kernel named cpp_fused_permute_1 does nothing but load and store in_out_ptr0.
Example code:
```
@torch._dynamo.optimize("inductor")
def fn(permute_6, view_10):
permute_5 = torch.ops.aten.permute.default(view_10, [0, 2, 1, 3])
clone_2 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format)
view_11 = torch.ops.aten.view.default(clone_2, [1024, -1, 32])
bmm = torch.ops.aten.bmm.default(view_11, permute_6)
permute_339 = torch.ops.aten.permute.default(view_11, [0, 2, 1])
return (bmm, permute_339)
permute_6 = rand_strided((1024, 32, 128), (4096, 1, 32), device='cpu', dtype=torch.float32)
view_10 = rand_strided((64, 128, 16, 32), (65536, 512, 32, 1), device='cpu', dtype=torch.float32)
out = fn(permute_6, view_10)
```
Output code (before this PR):
```
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
cpp_fused_clone_0 = async_compile.cpp('''
#include "/tmp/torchinductor_bzheng/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
#pragma omp parallel num_threads(80)
{
{
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(64L); i0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i1=static_cast<long>(0L); i1<static_cast<long>(16L); i1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i2=static_cast<long>(0L); i2<static_cast<long>(128L); i2+=static_cast<long>(1L))
{
for(long i3=static_cast<long>(0L); i3<static_cast<long>(32L); i3+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i3 + (32L*i1) + (512L*i2) + (65536L*i0)));
tmp0.store(out_ptr0 + static_cast<long>(i3 + (32L*i2) + (4096L*i1) + (65536L*i0)));
}
}
}
}
}
}
}
''')
cpp_fused_permute_1 = async_compile.cpp('''
#include "/tmp/torchinductor_bzheng/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
extern "C" void kernel(float* in_out_ptr0)
{
#pragma omp parallel num_threads(80)
{
{
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(4194304L); i0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<long>(i0));
tmp0.store(in_out_ptr0 + static_cast<long>(i0));
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1 = args
args.clear()
buf0 = empty_strided((64, 16, 128, 32), (65536, 4096, 32, 1), device='cpu', dtype=torch.float32)
cpp_fused_clone_0(c_void_p(arg1_1.data_ptr()), c_void_p(buf0.data_ptr()))
del arg1_1
buf1 = empty_strided((1024, 128, 128), (16384, 128, 1), device='cpu', dtype=torch.float32)
extern_kernels.bmm(as_strided(buf0, (1024, 128, 32), (4096, 32, 1)), arg0_1, out=buf1)
del arg0_1
buf2 = as_strided(buf0, (1024, 32, 128), (4096, 1, 32)); del buf0 # reuse
cpp_fused_permute_1(c_void_p(buf2.data_ptr()))
return (buf1, buf2, )
```
Output code (after this PR):
```
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
cpp_fused_clone_0 = async_compile.cpp('''
#include "/tmp/torchinductor_bzheng/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
#pragma omp parallel num_threads(80)
{
{
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(64L); i0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i1=static_cast<long>(0L); i1<static_cast<long>(16L); i1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i2=static_cast<long>(0L); i2<static_cast<long>(128L); i2+=static_cast<long>(1L))
{
for(long i3=static_cast<long>(0L); i3<static_cast<long>(32L); i3+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i3 + (32L*i1) + (512L*i2) + (65536L*i0)));
tmp0.store(out_ptr0 + static_cast<long>(i3 + (32L*i2) + (4096L*i1) + (65536L*i0)));
}
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1 = args
args.clear()
buf0 = empty_strided((64, 16, 128, 32), (65536, 4096, 32, 1), device='cpu', dtype=torch.float32)
cpp_fused_clone_0(c_void_p(arg1_1.data_ptr()), c_void_p(buf0.data_ptr()))
del arg1_1
buf1 = empty_strided((1024, 128, 128), (16384, 128, 1), device='cpu', dtype=torch.float32)
extern_kernels.bmm(as_strided(buf0, (1024, 128, 32), (4096, 32, 1)), arg0_1, out=buf1)
del arg0_1
return (buf1, as_strided(buf0, (1024, 32, 128), (4096, 1, 32)), )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102089
Approved by: https://github.com/jgong5, https://github.com/ngimel