pytorch
3790b505 - inductor: fix .to(memort_format) issue which doesn't generate right stride (#91948)

Commit
2 years ago
inductor: fix .to(memort_format) issue which doesn't generate right stride (#91948) Motivation: for **.to(memory_format),** the inductor doesn't generate the right stride, see the following example: ``` class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): x = x.to(memory_format=torch.contiguous_format) return x ``` the generated code doesn't do the memory format change and gets a wrong stride **(802816, 1, 14336, 256)**, it is not a contiguous stride. ``` from ctypes import c_void_p, c_long import torch import random from torch import empty_strided, as_strided, device from torch._inductor.codecache import AsyncCompile aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() async_compile.wait(globals()) del async_compile def call(args): arg0_1, = args args.clear() return (arg0_1, ) if __name__ == "__main__": from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((128, 256, 56, 56), (802816, 1, 14336, 256), device='cpu', dtype=torch.float32) print_performance(lambda: call([arg0_1])) ``` After this PR, the will have a memory format change: ``` from ctypes import c_void_p, c_long import torch import random from torch import empty_strided, as_strided, device from torch._inductor.codecache import AsyncCompile aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() kernel_cpp_0 = async_compile.cpp(''' #include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h" extern "C" void kernel(const float* __restrict__ in_ptr0, float* __restrict__ out_ptr0) { #pragma omp parallel num_threads(40) { { #pragma omp for for(long i0=0; i0<128; i0+=1) { #pragma GCC ivdep for(long i1=0; i1<256; i1+=1) { #pragma GCC ivdep for(long i2=0; i2<3136; i2+=1) { auto tmp0 = in_ptr0[i1 + (256*i2) + (802816*i0)]; out_ptr0[i2 + (3136*i1) + (802816*i0)] = tmp0; } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg0_1, = args args.clear() buf1 = empty_strided((128, 256, 56, 56), (802816, 3136, 56, 1), device='cpu', dtype=torch.float32) kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf1.data_ptr())) del arg0_1 return (buf1, ) if __name__ == "__main__": from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((128, 256, 56, 56), (802816, 1, 14336, 256), device='cpu', dtype=torch.float32) print_performance(lambda: call([arg0_1])) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/91948 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading