[inductor] fix output_stride of cat (#91233)
When the inputs to the ConcatKernel come from both an ExternKernel and Loops, the Loops output may still have a FlexibleLayout (with contiguous strides by default). When deciding the output stride of the ConcatKernel, this Loops output was wrongly assumed to be contiguous, so the output layout of the ConcatKernel was set to contiguous as well.
In this PR, we propose the following heuristic to decide the output layout of the ConcatKernel:
If any of the inputs to the ConcatKernel has a FixedLayout and is in channels-last format, we set the output of the ConcatKernel to channels-last as well.
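In simplified form, the heuristic amounts to something like the sketch below. This is an illustration only: the function and helper names are made up for this description and do not match the actual Inductor implementation.

```python
# Hypothetical sketch of the heuristic described above (not Inductor code).

def is_channels_last_4d(size, stride):
    # Channels-last (NHWC) strides for an NCHW-shaped tensor: the channel
    # dimension has stride 1 and the width dimension has stride == channels.
    n, c, h, w = size
    return stride[1] == 1 and stride[3] == c

def decide_cat_strides(input_layouts, output_size):
    # input_layouts: list of (kind, size, stride), where kind is "fixed" for
    # inputs whose layout is already frozen (e.g. ExternKernel outputs) and
    # "flexible" for Loops outputs whose strides are not decided yet.
    for kind, size, stride in input_layouts:
        if kind == "fixed" and is_channels_last_4d(size, stride):
            # Any fixed channels-last input makes the concat output channels-last.
            n, c, h, w = output_size
            return (c * h * w, 1, c * w, c)
    # Otherwise fall back to contiguous strides.
    n, c, h, w = output_size
    return (c * h * w, h * w, w, 1)
```

For the example below, the concat output has size (1, 69, 16, 16); with a fixed channels-last input this yields strides (17664, 1, 1104, 69), matching the "After" code, while without one the contiguous strides (17664, 256, 16, 1) of the "Before" code are kept.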
### Before
```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
const float* __restrict__ in_ptr1,
float* __restrict__ out_ptr0,
float* __restrict__ out_ptr1)
{
#pragma omp parallel num_threads(56)
{
#pragma omp for collapse(2)
for(long i0=0; i0<5; i0+=1)
{
for(long i1=0; i1<256; i1+=1)
{
{
{
auto tmp0 = in_ptr0[i0 + (5*i1)];
out_ptr0[i1 + (256*i0)] = tmp0;
}
}
}
}
#pragma omp for collapse(2)
for(long i0=0; i0<64; i0+=1)
{
for(long i1=0; i1<16; i1+=1)
{
#pragma GCC ivdep
for(long i2=0; i2<16; i2+=1)
{
{
{
auto tmp0 = in_ptr1[i0 + (128*i2) + (4096*i1)];
auto tmp1 = in_ptr1[64 + i0 + (128*i2) + (4096*i1)];
auto tmp3 = in_ptr1[2048 + i0 + (128*i2) + (4096*i1)];
auto tmp5 = in_ptr1[2112 + i0 + (128*i2) + (4096*i1)];
auto tmp2 = (tmp0 != tmp0) ? tmp0 : std::max(tmp1, tmp0);
auto tmp4 = (tmp2 != tmp2) ? tmp2 : std::max(tmp3, tmp2);
auto tmp6 = (tmp4 != tmp4) ? tmp4 : std::max(tmp5, tmp4);
out_ptr1[i2 + (16*i1) + (256*i0)] = tmp6;
}
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
    primals_1, primals_2, primals_3, primals_4 = args
    args.clear()
    buf0 = aten.convolution(primals_3, primals_1, primals_2, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
    assert_size_stride(buf0, (1, 5, 16, 16), (1280, 1, 80, 5))
    del primals_2
    buf3 = empty_strided((1, 69, 16, 16), (17664, 256, 16, 1), device='cpu', dtype=torch.float32)
    buf1 = as_strided(buf3, (1, 5, 16, 16), (17664, 256, 16, 1))  # alias
    buf2 = as_strided(buf3, (1, 64, 16, 16), (17664, 256, 16, 1), 1280)  # alias
    kernel_cpp_0(c_void_p(buf0.data_ptr()), c_void_p(primals_4.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()))
    del buf0
    del primals_4
    return (buf3, primals_1, primals_3, )
```
### After
```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
const float* __restrict__ in_ptr1,
float* __restrict__ out_ptr0,
float* __restrict__ out_ptr1)
{
#pragma omp parallel num_threads(56)
{
#pragma omp for
for(long i0=0; i0<256; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<5; i1+=1)
{
{
{
auto tmp0 = in_ptr0[i1 + (5*i0)];
out_ptr0[i1 + (69*i0)] = tmp0;
}
}
}
}
#pragma omp for collapse(2)
for(long i0=0; i0<16; i0+=1)
{
for(long i1=0; i1<16; i1+=1)
{
for(long i2=0; i2<4; i2+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr1 + (16*i2) + (128*i1) + (4096*i0));
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 64 + (16*i2) + (128*i1) + (4096*i0));
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr1 + 2048 + (16*i2) + (128*i1) + (4096*i0));
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr1 + 2112 + (16*i2) + (128*i1) + (4096*i0));
auto tmp2 = at::vec::maximum(tmp1, tmp0);
auto tmp4 = at::vec::maximum(tmp3, tmp2);
auto tmp6 = at::vec::maximum(tmp5, tmp4);
tmp6.store(out_ptr1 + (16*i2) + (69*i1) + (1104*i0));
}
#pragma omp simd simdlen(8)
for(long i2=64; i2<64; i2+=1)
{
auto tmp0 = in_ptr1[i2 + (128*i1) + (4096*i0)];
auto tmp1 = in_ptr1[64 + i2 + (128*i1) + (4096*i0)];
auto tmp3 = in_ptr1[2048 + i2 + (128*i1) + (4096*i0)];
auto tmp5 = in_ptr1[2112 + i2 + (128*i1) + (4096*i0)];
auto tmp2 = (tmp0 != tmp0) ? tmp0 : std::max(tmp1, tmp0);
auto tmp4 = (tmp2 != tmp2) ? tmp2 : std::max(tmp3, tmp2);
auto tmp6 = (tmp4 != tmp4) ? tmp4 : std::max(tmp5, tmp4);
out_ptr1[i2 + (69*i1) + (1104*i0)] = tmp6;
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
    primals_1, primals_2, primals_3, primals_4 = args
    args.clear()
    buf0 = aten.convolution(primals_3, primals_1, primals_2, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
    assert_size_stride(buf0, (1, 5, 16, 16), (1280, 1, 80, 5))
    del primals_2
    buf3 = empty_strided((1, 69, 16, 16), (17664, 1, 1104, 69), device='cpu', dtype=torch.float32)
    buf1 = as_strided(buf3, (1, 5, 16, 16), (17664, 1, 1104, 69))  # alias
    buf2 = as_strided(buf3, (1, 64, 16, 16), (17664, 1, 1104, 69), 5)  # alias
    kernel_cpp_0(c_void_p(buf0.data_ptr()), c_void_p(primals_4.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()))
    del buf0
    del primals_4
    return (buf3, primals_1, primals_3, )
```
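For reference, a minimal script of the kind of graph that produces the code above. The module definition and input shapes are reconstructed from the generated kernels and are assumptions, not the actual test case.

```python
import torch

# Hypothetical minimal reproducer; layer and input shapes are inferred from the
# generated code above, not taken from the real test.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, kernel_size=1)          # lowered as ExternKernel (convolution)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)   # lowered as Loops (max pooling)

    def forward(self, x, y):
        # cat of a fixed channels-last output (conv) and a flexible Loops output (pool)
        return torch.cat([self.conv(x), self.pool(y)], dim=1)

m = M().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last)
y = torch.randn(1, 64, 32, 32).to(memory_format=torch.channels_last)

compiled = torch.compile(m)
out = compiled(x, y)
# With this fix, the concatenated (1, 69, 16, 16) output keeps channels-last
# strides, i.e. (17664, 1, 1104, 69) instead of contiguous (17664, 256, 16, 1).
print(out.stride())
```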
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91233
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel