TorchDynamo: set output stride using eager output for cat (#89477)
For the squeezenet1_1 and densenet121 models, cat's post-op is always a convolution. On the channels-last path, the current cat lowering always sets the output to contiguous format, while the convolution's input requires channels last, so a memory copy is always inserted before the convolution. This PR uses the eager model's output stride to set cat's output format, which removes that copy.
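A minimal sketch of the pattern this targets (module and shapes are illustrative, chosen to match the generated code below; assumes a recent build with `torch.compile`):
```python
import torch

class CatConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 6 input channels = two concatenated 3-channel tensors
        self.conv = torch.nn.Conv2d(6, 3, kernel_size=1)

    def forward(self, x, y):
        # cat feeds a convolution; with channels-last inputs, eager mode
        # produces a channels-last cat output, and the compiled graph
        # should match that stride to avoid a reorder before the conv
        return self.conv(torch.cat([x, y], dim=1))

mod = CatConv().to(memory_format=torch.channels_last).eval()
x = torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last)
y = torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last)
with torch.no_grad():
    out = torch.compile(mod)(x, y)
```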
Before:
```
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/ik/cikrybpw4xhois4wll6h5afsswjrhpsb6gslcxrntzqtlyw2btey.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       const float* __restrict__ in_ptr1,
                       const float* __restrict__ in_ptr2,
                       float* __restrict__ out_ptr0,
                       float* __restrict__ out_ptr1,
                       float* __restrict__ out_ptr2)
{
    #pragma GCC ivdep
    for(long i0=0; i0<3; i0+=1)
    {
        #pragma GCC ivdep
        for(long i1=0; i1<256; i1+=1)
        {
            {
                {
                    auto tmp0 = in_ptr0[i0 + (3*i1)];
                    out_ptr0[i1 + (256*i0)] = tmp0;
                }
            }
        }
    }
    #pragma GCC ivdep
    for(long i0=0; i0<3; i0+=1)
    {
        #pragma GCC ivdep
        for(long i1=0; i1<256; i1+=1)
        {
            {
                {
                    auto tmp0 = in_ptr1[i0 + (3*i1)];
                    out_ptr1[i1 + (256*i0)] = tmp0;
                }
            }
        }
    }
    #pragma GCC ivdep
    for(long i0=0; i0<6; i0+=1)
    {
        #pragma GCC ivdep
        for(long i1=0; i1<256; i1+=1)
        {
            {
                {
                    auto tmp0 = in_ptr2[i1 + (256*i0)];
                    out_ptr2[i0 + (6*i1)] = tmp0;
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    buf2 = empty_strided((1, 6, 16, 16), (1536, 256, 16, 1), device='cpu', dtype=torch.float32)
    buf0 = as_strided(buf2, (1, 3, 16, 16), (1536, 256, 16, 1)) # alias
    buf1 = as_strided(buf2, (1, 3, 16, 16), (1536, 256, 16, 1), 768) # alias
    buf3 = empty_strided((1, 6, 16, 16), (1536, 1, 96, 6), device='cpu', dtype=torch.float32)
    kernel_cpp_0(c_void_p(arg2_1.data_ptr()), c_void_p(arg3_1.data_ptr()), c_void_p(buf2.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf3.data_ptr()))
    del arg2_1
    del arg3_1
    del buf0
    del buf1
    del buf2
    buf4 = aten.convolution(buf3, arg0_1, arg1_1, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
    assert_size_stride(buf4, (1, 3, 16, 16), (768, 1, 48, 3))
    del arg0_1
    del arg1_1
    return (buf4, )
```
After:
```
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/ik/cikrybpw4xhois4wll6h5afsswjrhpsb6gslcxrntzqtlyw2btey.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       const float* __restrict__ in_ptr1,
                       float* __restrict__ out_ptr0,
                       float* __restrict__ out_ptr1)
{
    #pragma GCC ivdep
    for(long i0=0; i0<256; i0+=1)
    {
        #pragma GCC ivdep
        for(long i1=0; i1<3; i1+=1)
        {
            {
                {
                    auto tmp0 = in_ptr0[i1 + (3*i0)];
                    out_ptr0[i1 + (6*i0)] = tmp0;
                }
            }
        }
    }
    #pragma GCC ivdep
    for(long i0=0; i0<256; i0+=1)
    {
        #pragma GCC ivdep
        for(long i1=0; i1<3; i1+=1)
        {
            {
                {
                    auto tmp0 = in_ptr1[i1 + (3*i0)];
                    out_ptr1[i1 + (6*i0)] = tmp0;
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    buf2 = empty_strided((1, 6, 16, 16), (1536, 1, 96, 6), device='cpu', dtype=torch.float32)
    buf0 = as_strided(buf2, (1, 3, 16, 16), (1536, 1, 96, 6)) # alias
    buf1 = as_strided(buf2, (1, 3, 16, 16), (1536, 1, 96, 6), 3) # alias
    kernel_cpp_0(c_void_p(arg2_1.data_ptr()), c_void_p(arg3_1.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()))
    del arg2_1
    del arg3_1
    del buf0
    del buf1
    buf3 = aten.convolution(buf2, arg0_1, arg1_1, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
    assert_size_stride(buf3, (1, 3, 16, 16), (768, 1, 48, 3))
    del arg0_1
    del arg1_1
    return (buf3, )
```
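Compared to the "before" output, the third copy loop and the intermediate `buf3` are gone: cat writes directly into a channels-last `buf2` (strides `(1536, 1, 96, 6)`) that feeds `aten.convolution`. That stride is exactly what eager mode produces for this cat, which can be checked with a small sketch (tensor names illustrative):
```python
import torch

x = torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last)
y = torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last)
out = torch.cat([x, y], dim=1)
# eager cat of channels-last inputs stays channels last, matching buf2 above
print(out.stride())  # (1536, 1, 96, 6)
print(out.is_contiguous(memory_format=torch.channels_last))  # True
```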
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89477
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel