TorchDynamo: always convert a FlexibleLayout to a FixedLayout when given a stride_order (#89904)
For convolution, we always call **require_stride_order** to convert the input to the target stride order. If the original input's layout is a FlexibleLayout, there is always a memory copy, because **is_stride_order_storage_and_layout** only checks the initial stride order. Since a FlexibleLayout means the layout can still be changed, when the user gives a stride order we should always convert the FlexibleLayout to a FixedLayout using the given stride order.
Given a CV use case where the max_pooling output is used by two convolutions, there are two memory copies:
```
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
float* __restrict__ out_ptr0,
float* __restrict__ out_ptr1,
float* __restrict__ out_ptr2)
{
#pragma GCC ivdep
for(long i0=0; i0<128; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<3; i1+=1)
{
#pragma GCC ivdep
for(long i2=0; i2<3; i2+=1)
{
#pragma GCC ivdep
for(long i3=0; i3<3; i3+=1)
{
{
{
auto tmp0 = in_ptr0[i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp1 = in_ptr0[3 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp3 = in_ptr0[6 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp5 = in_ptr0[21 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp7 = in_ptr0[24 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp9 = in_ptr0[27 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp11 = in_ptr0[42 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp13 = in_ptr0[45 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp15 = in_ptr0[48 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp2 = (tmp0 != tmp0) ? tmp0 : std::max(tmp1, tmp0);
auto tmp4 = (tmp2 != tmp2) ? tmp2 : std::max(tmp3, tmp2);
auto tmp6 = (tmp4 != tmp4) ? tmp4 : std::max(tmp5, tmp4);
auto tmp8 = (tmp6 != tmp6) ? tmp6 : std::max(tmp7, tmp6);
auto tmp10 = (tmp8 != tmp8) ? tmp8 : std::max(tmp9, tmp8);
auto tmp12 = (tmp10 != tmp10) ? tmp10 : std::max(tmp11, tmp10);
auto tmp14 = (tmp12 != tmp12) ? tmp12 : std::max(tmp13, tmp12);
auto tmp16 = (tmp14 != tmp14) ? tmp14 : std::max(tmp15, tmp14);
out_ptr0[i3 + (3*i2) + (9*i1) + (27*i0)] = tmp16;
}
}
}
}
}
}
#pragma GCC ivdep
for(long i0=0; i0<128; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<3; i1+=1)
{
#pragma GCC ivdep
for(long i2=0; i2<9; i2+=1)
{
{
{
auto tmp0 = out_ptr0[i1 + (3*i2) + (27*i0)];
out_ptr1[i1 + (3*i2) + (27*i0)] = tmp0;
out_ptr2[i1 + (3*i2) + (27*i0)] = tmp0;
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
args.clear()
buf0 = empty_strided((128, 3, 3, 3), (27, 1, 9, 3), device='cpu', dtype=torch.float32)
buf2 = empty_strided((128, 3, 3, 3), (27, 1, 9, 3), device='cpu', dtype=torch.float32)
buf4 = empty_strided((128, 3, 3, 3), (27, 1, 9, 3), device='cpu', dtype=torch.float32)
kernel_cpp_0(c_void_p(arg4_1.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf2.data_ptr()), c_void_p(buf4.data_ptr()))
del arg4_1
del buf0
buf3 = torch.ops.mkldnn._convolution_pointwise(buf2, arg0_1, arg1_1, (0, 0), (1, 1), (1, 1), 1, 'none', [], '')
assert_size_stride(buf3, (128, 3, 3, 3), (27, 1, 9, 3))
del arg0_1
del arg1_1
del buf2
buf5 = torch.ops.mkldnn._convolution_pointwise(buf4, arg2_1, arg3_1, (0, 0), (1, 1), (1, 1), 1, 'none', [], '')
assert_size_stride(buf5, (128, 3, 3, 3), (27, 1, 9, 3))
del arg2_1
del arg3_1
return (buf3, buf5, )
```
After this PR, the generated code removes the redundant memory copy:
```
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
float* __restrict__ out_ptr0)
{
#pragma GCC ivdep
for(long i0=0; i0<128; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<3; i1+=1)
{
#pragma GCC ivdep
for(long i2=0; i2<3; i2+=1)
{
#pragma GCC ivdep
for(long i3=0; i3<3; i3+=1)
{
{
{
auto tmp0 = in_ptr0[i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp1 = in_ptr0[3 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp3 = in_ptr0[6 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp5 = in_ptr0[21 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp7 = in_ptr0[24 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp9 = in_ptr0[27 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp11 = in_ptr0[42 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp13 = in_ptr0[45 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp15 = in_ptr0[48 + i3 + (6*i2) + (42*i1) + (147*i0)];
auto tmp2 = (tmp0 != tmp0) ? tmp0 : std::max(tmp1, tmp0);
auto tmp4 = (tmp2 != tmp2) ? tmp2 : std::max(tmp3, tmp2);
auto tmp6 = (tmp4 != tmp4) ? tmp4 : std::max(tmp5, tmp4);
auto tmp8 = (tmp6 != tmp6) ? tmp6 : std::max(tmp7, tmp6);
auto tmp10 = (tmp8 != tmp8) ? tmp8 : std::max(tmp9, tmp8);
auto tmp12 = (tmp10 != tmp10) ? tmp10 : std::max(tmp11, tmp10);
auto tmp14 = (tmp12 != tmp12) ? tmp12 : std::max(tmp13, tmp12);
auto tmp16 = (tmp14 != tmp14) ? tmp14 : std::max(tmp15, tmp14);
out_ptr0[i3 + (3*i2) + (9*i1) + (27*i0)] = tmp16;
}
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
args.clear()
buf0 = empty_strided((128, 3, 3, 3), (27, 1, 9, 3), device='cpu', dtype=torch.float32)
kernel_cpp_0(c_void_p(arg4_1.data_ptr()), c_void_p(buf0.data_ptr()))
del arg4_1
buf2 = torch.ops.mkldnn._convolution_pointwise(buf0, arg0_1, arg1_1, (0, 0), (1, 1), (1, 1), 1, 'none', [], '')
assert_size_stride(buf2, (128, 3, 3, 3), (27, 1, 9, 3))
del arg0_1
del arg1_1
buf3 = torch.ops.mkldnn._convolution_pointwise(buf0, arg2_1, arg3_1, (0, 0), (1, 1), (1, 1), 1, 'none', [], '')
assert_size_stride(buf3, (128, 3, 3, 3), (27, 1, 9, 3))
del arg2_1
del arg3_1
return (buf2, buf3, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89904
Approved by: https://github.com/jansel