pytorch
2555971b - [inductor] fix output_stride of cat (#91233)

[inductor] fix output_stride of cat (#91233)

When the inputs to the ConcatKernel come from both ExternKernel and Loops, the output of the Loops node may still carry a FlexibleLayout (with contiguous strides). When deciding the output stride of the ConcatKernel, that Loops output was wrongly assumed to be contiguous, so the output layout of the ConcatKernel was also set to contiguous. In this PR, we propose the following heuristic to decide the output layout of the ConcatKernel: if any input to the ConcatKernel has a FixedLayout in channels-last format, we set the output of the ConcatKernel to channels-last as well.

### Before

```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       const float* __restrict__ in_ptr1,
                       float* __restrict__ out_ptr0,
                       float* __restrict__ out_ptr1)
{
    #pragma omp parallel num_threads(56)
    {
        #pragma omp for collapse(2)
        for(long i0=0; i0<5; i0+=1)
        {
            for(long i1=0; i1<256; i1+=1)
            {
                {
                    {
                        auto tmp0 = in_ptr0[i0 + (5*i1)];
                        out_ptr0[i1 + (256*i0)] = tmp0;
                    }
                }
            }
        }
        #pragma omp for collapse(2)
        for(long i0=0; i0<64; i0+=1)
        {
            for(long i1=0; i1<16; i1+=1)
            {
                #pragma GCC ivdep
                for(long i2=0; i2<16; i2+=1)
                {
                    {
                        {
                            auto tmp0 = in_ptr1[i0 + (128*i2) + (4096*i1)];
                            auto tmp1 = in_ptr1[64 + i0 + (128*i2) + (4096*i1)];
                            auto tmp3 = in_ptr1[2048 + i0 + (128*i2) + (4096*i1)];
                            auto tmp5 = in_ptr1[2112 + i0 + (128*i2) + (4096*i1)];
                            auto tmp2 = (tmp0 != tmp0) ? tmp0 : std::max(tmp1, tmp0);
                            auto tmp4 = (tmp2 != tmp2) ? tmp2 : std::max(tmp3, tmp2);
                            auto tmp6 = (tmp4 != tmp4) ? tmp4 : std::max(tmp5, tmp4);
                            out_ptr1[i2 + (16*i1) + (256*i0)] = tmp6;
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2, primals_3, primals_4 = args
    args.clear()
    buf0 = aten.convolution(primals_3, primals_1, primals_2, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
    assert_size_stride(buf0, (1, 5, 16, 16), (1280, 1, 80, 5))
    del primals_2
    buf3 = empty_strided((1, 69, 16, 16), (17664, 256, 16, 1), device='cpu', dtype=torch.float32)
    buf1 = as_strided(buf3, (1, 5, 16, 16), (17664, 256, 16, 1))  # alias
    buf2 = as_strided(buf3, (1, 64, 16, 16), (17664, 256, 16, 1), 1280)  # alias
    kernel_cpp_0(c_void_p(buf0.data_ptr()), c_void_p(primals_4.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()))
    del buf0
    del primals_4
    return (buf3, primals_1, primals_3, )
```

### After

```python
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunyuan/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       const float* __restrict__ in_ptr1,
                       float* __restrict__ out_ptr0,
                       float* __restrict__ out_ptr1)
{
    #pragma omp parallel num_threads(56)
    {
        #pragma omp for
        for(long i0=0; i0<256; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<5; i1+=1)
            {
                {
                    {
                        auto tmp0 = in_ptr0[i1 + (5*i0)];
                        out_ptr0[i1 + (69*i0)] = tmp0;
                    }
                }
            }
        }
        #pragma omp for collapse(2)
        for(long i0=0; i0<16; i0+=1)
        {
            for(long i1=0; i1<16; i1+=1)
            {
                for(long i2=0; i2<4; i2+=1)
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr1 + (16*i2) + (128*i1) + (4096*i0));
                    auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 64 + (16*i2) + (128*i1) + (4096*i0));
                    auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr1 + 2048 + (16*i2) + (128*i1) + (4096*i0));
                    auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr1 + 2112 + (16*i2) + (128*i1) + (4096*i0));
                    auto tmp2 = at::vec::maximum(tmp1, tmp0);
                    auto tmp4 = at::vec::maximum(tmp3, tmp2);
                    auto tmp6 = at::vec::maximum(tmp5, tmp4);
                    tmp6.store(out_ptr1 + (16*i2) + (69*i1) + (1104*i0));
                }
                #pragma omp simd simdlen(8)
                for(long i2=64; i2<64; i2+=1)
                {
                    auto tmp0 = in_ptr1[i2 + (128*i1) + (4096*i0)];
                    auto tmp1 = in_ptr1[64 + i2 + (128*i1) + (4096*i0)];
                    auto tmp3 = in_ptr1[2048 + i2 + (128*i1) + (4096*i0)];
                    auto tmp5 = in_ptr1[2112 + i2 + (128*i1) + (4096*i0)];
                    auto tmp2 = (tmp0 != tmp0) ? tmp0 : std::max(tmp1, tmp0);
                    auto tmp4 = (tmp2 != tmp2) ? tmp2 : std::max(tmp3, tmp2);
                    auto tmp6 = (tmp4 != tmp4) ? tmp4 : std::max(tmp5, tmp4);
                    out_ptr1[i2 + (69*i1) + (1104*i0)] = tmp6;
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2, primals_3, primals_4 = args
    args.clear()
    buf0 = aten.convolution(primals_3, primals_1, primals_2, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
    assert_size_stride(buf0, (1, 5, 16, 16), (1280, 1, 80, 5))
    del primals_2
    buf3 = empty_strided((1, 69, 16, 16), (17664, 1, 1104, 69), device='cpu', dtype=torch.float32)
    buf1 = as_strided(buf3, (1, 5, 16, 16), (17664, 1, 1104, 69))  # alias
    buf2 = as_strided(buf3, (1, 64, 16, 16), (17664, 1, 1104, 69), 5)  # alias
    kernel_cpp_0(c_void_p(buf0.data_ptr()), c_void_p(primals_4.data_ptr()), c_void_p(buf1.data_ptr()), c_void_p(buf2.data_ptr()))
    del buf0
    del primals_4
    return (buf3, primals_1, primals_3, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91233
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
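The heuristic can be illustrated with a minimal sketch. This is not the code from the PR; the helper names (`is_channels_last_strides`, `decide_cat_output_stride`) and the `(kind, size, stride)` tuples are hypothetical stand-ins for the layout information Inductor tracks. It only shows the decision rule: if any fixed-layout input is channels-last, choose channels-last strides for the concat output; otherwise fall back to contiguous.

```python
# Hypothetical sketch of the layout-decision heuristic described above.
# Not the actual PR code; names and data structures are illustrative only.

def is_channels_last_strides(size, stride):
    # NCHW sizes with NHWC-style strides: the channel stride is 1 and
    # spatial strides are multiples of the channel count.
    n, c, h, w = size
    return tuple(stride) == (c * h * w, 1, c * w, c)

def decide_cat_output_stride(input_layouts, output_size):
    """Pick output strides for a concat given its input layouts.

    input_layouts: list of (kind, size, stride) tuples, where kind is
    "fixed" (layout already decided, e.g. an ExternKernel output) or
    "flexible" (a Loops output whose strides can still be changed).
    """
    n, c, h, w = output_size
    for kind, size, stride in input_layouts:
        # If any fixed-layout input is channels last, propagate that
        # format to the concat output instead of assuming contiguous.
        if kind == "fixed" and is_channels_last_strides(size, stride):
            return (c * h * w, 1, c * w, c)   # channels-last strides
    return (c * h * w, h * w, w, 1)           # default: contiguous strides

# Example mirroring the commit: a (1, 5, 16, 16) channels-last conv output
# concatenated with a (1, 64, 16, 16) Loops output along the channel dim.
conv_layout = ("fixed", (1, 5, 16, 16), (1280, 1, 80, 5))
loops_layout = ("flexible", (1, 64, 16, 16), (16384, 256, 16, 1))
print(decide_cat_output_stride([conv_layout, loops_layout], (1, 69, 16, 16)))
# -> (17664, 1, 1104, 69)
```

Under these assumptions the sketch reproduces the effect visible in the wrapper code above: `buf3` is allocated with channels-last strides `(17664, 1, 1104, 69)` instead of the contiguous `(17664, 256, 16, 1)`, so the channels-last convolution output is copied into the concat buffer without a layout conversion.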