pytorch
900f8886 - inductor: make as_strided support non-contiguous input and always fix its input layout using eager stride (#92063)

Commit
2 years ago
inductor: make as_strided support non-contiguous input and always fix its input layout using eager stride (#92063) Given the following small case: ``` import torch import torch._dynamo class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): return torch.as_strided(x + 1, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680))+ 2 x = torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last) model= Model().eval() model = model.to(memory_format=torch.channels_last) ref = model(x) with torch.no_grad(): opt_model = torch._dynamo.optimize('inductor')(model) with torch.no_grad(): for i in range(2): y1 = opt_model(x) print(torch.equal(ref, y1)) ``` inductor always gets a wrong result: ``` from ctypes import c_void_p, c_long import torch import random from torch import empty_strided, as_strided, device from torch._inductor.codecache import AsyncCompile from torch._inductor.select_algorithm import extern_kernels aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() kernel_cpp_0 = async_compile.cpp(''' #include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h" extern "C" void kernel(const float* __restrict__ in_ptr0, float* __restrict__ out_ptr0, float* __restrict__ out_ptr1) { #pragma omp parallel num_threads(40) { { #pragma omp for for(long i0=0; i0<8; i0+=1) { #pragma GCC ivdep for(long i1=0; i1<384; i1+=1) { #pragma GCC ivdep for(long i2=0; i2<400; i2+=1) { auto tmp0 = in_ptr0[i1 + (384*i2) + (153600*i0)]; auto tmp1 = static_cast<float>(1); auto tmp2 = tmp0 + tmp1; out_ptr0[i2 + (400*i1) + (153600*i0)] = tmp2; } } } } { #pragma omp for collapse(2) for(long i0=0; i0<8; i0+=1) { for(long i1=0; i1<2; i1+=1) { for(long i2=0; i2<5760; i2+=1) { auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + (16*i2) + (61440*i1) + (153600*i0)); auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(2)); auto tmp2 = tmp0 + tmp1; 
tmp2.store(out_ptr1 + (16*i2) + (92160*i1) + (184320*i0)); } #pragma omp simd simdlen(8) for(long i2=92160; i2<92160; i2+=1) { auto tmp0 = out_ptr0[i2 + (61440*i1) + (153600*i0)]; auto tmp1 = static_cast<float>(2); auto tmp2 = tmp0 + tmp1; out_ptr1[i2 + (92160*i1) + (184320*i0)] = tmp2; } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg0_1, = args args.clear() buf0 = empty_strided((8, 384, 20, 20), (153600, 400, 20, 1), device='cpu', dtype=torch.float32) buf1 = empty_strided((8, 384, 2, 20, 12), (184320, 1, 92160, 384, 7680), device='cpu', dtype=torch.float32) kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr())) del arg0_1 return (buf1, ) if __name__ == "__main__": from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((8, 384, 20, 20), (153600, 1, 7680, 384), device='cpu', dtype=torch.float32) print_performance(lambda: call([arg0_1])) ``` The reason is that the **as_strided** lowering step always converts its input to a contiguous layout, which does not match the input stride in the eager model. 
``` class <lambda>(torch.nn.Module): def forward(self, arg0_1: f32[8, 384, 20, 20]): # File: model_test.py:52, code: return torch.as_strided(x + 1, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680))+ 2 add: f32[8, 384, 20, 20] = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None as_strided: f32[8, 384, 2, 20, 12] = torch.ops.aten.as_strided.default(add, [8, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680]); add = None add_1: f32[8, 384, 2, 20, 12] = torch.ops.aten.add.Tensor(as_strided, 2); as_strided = None return (add_1,) ``` This PR always fixes the layout of the **as_strided** input to the eager model's stride, and also makes **as_strided** support channels_last input: ``` from ctypes import c_void_p, c_long import torch import random from torch import empty_strided, as_strided, device from torch._inductor.codecache import AsyncCompile from torch._inductor.select_algorithm import extern_kernels aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() kernel_cpp_0 = async_compile.cpp(''' #include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h" extern "C" void kernel(const float* __restrict__ in_ptr0, float* __restrict__ out_ptr0, float* __restrict__ out_ptr1) { #pragma omp parallel num_threads(40) { { #pragma omp for for(long i0=0; i0<76800; i0+=1) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 16*i0); auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(1)); auto tmp2 = tmp0 + tmp1; tmp2.store(out_ptr0 + 16*i0); } #pragma omp for simd simdlen(8) for(long i0=1228800; i0<1228800; i0+=1) { auto tmp0 = in_ptr0[i0]; auto tmp1 = static_cast<float>(1); auto tmp2 = tmp0 + tmp1; out_ptr0[i0] = tmp2; } } { #pragma omp for collapse(2) for(long i0=0; i0<8; i0+=1) { for(long i1=0; i1<2; i1+=1) { for(long i2=0; i2<5760; i2+=1) { auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + (16*i2) + (61440*i1) + (153600*i0)); auto tmp1 = 
at::vec::Vectorized<float>(static_cast<float>(2)); auto tmp2 = tmp0 + tmp1; tmp2.store(out_ptr1 + (16*i2) + (92160*i1) + (184320*i0)); } #pragma omp simd simdlen(8) for(long i2=92160; i2<92160; i2+=1) { auto tmp0 = out_ptr0[i2 + (61440*i1) + (153600*i0)]; auto tmp1 = static_cast<float>(2); auto tmp2 = tmp0 + tmp1; out_ptr1[i2 + (92160*i1) + (184320*i0)] = tmp2; } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg0_1, = args args.clear() buf0 = empty_strided((8, 384, 20, 20), (153600, 1, 7680, 384), device='cpu', dtype=torch.float32) buf1 = empty_strided((8, 384, 2, 20, 12), (184320, 1, 92160, 384, 7680), device='cpu', dtype=torch.float32) kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr())) del arg0_1 return (buf1, ) if __name__ == "__main__": from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((8, 384, 20, 20), (153600, 1, 7680, 384), device='cpu', dtype=torch.float32) print_performance(lambda: call([arg0_1])) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/92063 Approved by: https://github.com/jansel
Author
Committer
Parents
Loading