inductor: make as_strided support non-contiguous input and always fix its input layout using eager stride (#92063)
Given the following small case:
```
import torch
import torch._dynamo
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return torch.as_strided(x + 1, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680))+ 2
x = torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)
model= Model().eval()
model = model.to(memory_format=torch.channels_last)
ref = model(x)
with torch.no_grad():
opt_model = torch._dynamo.optimize('inductor')(model)
with torch.no_grad():
for i in range(2):
y1 = opt_model(x)
print(torch.equal(ref, y1))
```
Inductor always produces a wrong result for this case. The generated code is:
```
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
float* __restrict__ out_ptr0,
float* __restrict__ out_ptr1)
{
#pragma omp parallel num_threads(40)
{
{
#pragma omp for
for(long i0=0; i0<8; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<384; i1+=1)
{
#pragma GCC ivdep
for(long i2=0; i2<400; i2+=1)
{
auto tmp0 = in_ptr0[i1 + (384*i2) + (153600*i0)];
auto tmp1 = static_cast<float>(1);
auto tmp2 = tmp0 + tmp1;
out_ptr0[i2 + (400*i1) + (153600*i0)] = tmp2;
}
}
}
}
{
#pragma omp for collapse(2)
for(long i0=0; i0<8; i0+=1)
{
for(long i1=0; i1<2; i1+=1)
{
for(long i2=0; i2<5760; i2+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + (16*i2) + (61440*i1) + (153600*i0));
auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(2));
auto tmp2 = tmp0 + tmp1;
tmp2.store(out_ptr1 + (16*i2) + (92160*i1) + (184320*i0));
}
#pragma omp simd simdlen(8)
for(long i2=92160; i2<92160; i2+=1)
{
auto tmp0 = out_ptr0[i2 + (61440*i1) + (153600*i0)];
auto tmp1 = static_cast<float>(2);
auto tmp2 = tmp0 + tmp1;
out_ptr1[i2 + (92160*i1) + (184320*i0)] = tmp2;
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, = args
args.clear()
buf0 = empty_strided((8, 384, 20, 20), (153600, 400, 20, 1), device='cpu', dtype=torch.float32)
buf1 = empty_strided((8, 384, 2, 20, 12), (184320, 1, 92160, 384, 7680), device='cpu', dtype=torch.float32)
kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()))
del arg0_1
return (buf1, )
if __name__ == "__main__":
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((8, 384, 20, 20), (153600, 1, 7680, 384), device='cpu', dtype=torch.float32)
print_performance(lambda: call([arg0_1]))
```
The reason is that the **as_strided** lowering step always converts its input to a contiguous layout, which does not match the input's stride in the eager model.
```
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[8, 384, 20, 20]):
# File: model_test.py:52, code: return torch.as_strided(x + 1, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680))+ 2
add: f32[8, 384, 20, 20] = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
as_strided: f32[8, 384, 2, 20, 12] = torch.ops.aten.as_strided.default(add, [8, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680]); add = None
add_1: f32[8, 384, 2, 20, 12] = torch.ops.aten.add.Tensor(as_strided, 2); as_strided = None
return (add_1,)
```
This PR always fixes the **as_strided** input's stride to the eager model's stride, and also makes **as_strided** support channels_last input. The generated code becomes:
```
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
float* __restrict__ out_ptr0,
float* __restrict__ out_ptr1)
{
#pragma omp parallel num_threads(40)
{
{
#pragma omp for
for(long i0=0; i0<76800; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 16*i0);
auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(1));
auto tmp2 = tmp0 + tmp1;
tmp2.store(out_ptr0 + 16*i0);
}
#pragma omp for simd simdlen(8)
for(long i0=1228800; i0<1228800; i0+=1)
{
auto tmp0 = in_ptr0[i0];
auto tmp1 = static_cast<float>(1);
auto tmp2 = tmp0 + tmp1;
out_ptr0[i0] = tmp2;
}
}
{
#pragma omp for collapse(2)
for(long i0=0; i0<8; i0+=1)
{
for(long i1=0; i1<2; i1+=1)
{
for(long i2=0; i2<5760; i2+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + (16*i2) + (61440*i1) + (153600*i0));
auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(2));
auto tmp2 = tmp0 + tmp1;
tmp2.store(out_ptr1 + (16*i2) + (92160*i1) + (184320*i0));
}
#pragma omp simd simdlen(8)
for(long i2=92160; i2<92160; i2+=1)
{
auto tmp0 = out_ptr0[i2 + (61440*i1) + (153600*i0)];
auto tmp1 = static_cast<float>(2);
auto tmp2 = tmp0 + tmp1;
out_ptr1[i2 + (92160*i1) + (184320*i0)] = tmp2;
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, = args
args.clear()
buf0 = empty_strided((8, 384, 20, 20), (153600, 1, 7680, 384), device='cpu', dtype=torch.float32)
buf1 = empty_strided((8, 384, 2, 20, 12), (184320, 1, 92160, 384, 7680), device='cpu', dtype=torch.float32)
kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()))
del arg0_1
return (buf1, )
if __name__ == "__main__":
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((8, 384, 20, 20), (153600, 1, 7680, 384), device='cpu', dtype=torch.float32)
print_performance(lambda: call([arg0_1]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92063
Approved by: https://github.com/jansel