inductor: align the decomposition output stride with non-decomposition path for torch.lerp (#93336)
As title, we need to align the decomposition output stride with the non-decomposition path for torch.lerp. We also enable its lowering path for Inductor.
After this PR for the following case:
```
def fn(i0, i1):
# i0: (10, 3, 10)
# i1: (3, 10, 10)
x1 = i0.transpose(-2, -3)
#y = torch.lerp(x1, x1, 70000)
z = torch.lerp(i1, x1, 70000)
return z
x0 = torch.rand(10, 3, 10)
x1 = torch.rand(3, 10, 10)
ret_eager = fn(x0, x1)
print('==== Eager mode OK! ====')
compiled = torch.compile(fn, fullgraph=True)
ret_compiled = compiled(x0, x1)
print('==== compile mode OK! ====')
ret_compiled = compiled(x0, x1)
print(torch.equal(ret_eager, ret_compiled))
print(ret_eager.stride()==ret_compiled.stride())
```
the Inductor output code will be like (CPU):
```
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/77/c7773nj5pwikpmm2pwa62rcudlf7p3if7eyqb5k4sjsvewwje4le.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
const float* __restrict__ in_ptr1,
float* __restrict__ out_ptr0)
{
{
#pragma GCC ivdep
for(long i0=0; i0<3; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<10; i1+=1)
{
for(long i2=0; i2<0; i2+=1)
{
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr0 + (10*i0) + (16*i2) + (30*i1));
auto tmp8 = at::vec::Vectorized<float>::loadu(in_ptr1 + (10*i1) + (16*i2) + (100*i0));
auto tmp0 = at::vec::Vectorized<float>(static_cast<float>(70000.0));
auto tmp1 = tmp0.abs();
auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(0.5));
auto tmp3 = tmp1 >= tmp2;
auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(1));
auto tmp5 = tmp0 - tmp4;
auto tmp6 = decltype(tmp5)::blendv(tmp0, tmp5, tmp3);
auto tmp9 = tmp7 - tmp8;
auto tmp10 = tmp6 * tmp9;
auto tmp11 = decltype(tmp7)::blendv(tmp8, tmp7, tmp3);
auto tmp12 = tmp10 + tmp11;
tmp12.store(out_ptr0 + (10*i1) + (16*i2) + (100*i0));
}
#pragma omp simd simdlen(8)
for(long i2=0; i2<10; i2+=1)
{
auto tmp7 = in_ptr0[i2 + (10*i0) + (30*i1)];
auto tmp8 = in_ptr1[i2 + (10*i1) + (100*i0)];
auto tmp0 = static_cast<float>(70000.0);
auto tmp1 = std::abs(tmp0);
auto tmp2 = static_cast<float>(0.5);
auto tmp3 = tmp1 >= tmp2;
auto tmp4 = static_cast<float>(1);
auto tmp5 = tmp0 - tmp4;
auto tmp6 = tmp3 ? tmp5 : tmp0;
auto tmp9 = tmp7 - tmp8;
auto tmp10 = tmp6 * tmp9;
auto tmp11 = tmp3 ? tmp7 : tmp8;
auto tmp12 = tmp10 + tmp11;
out_ptr0[i2 + (10*i1) + (100*i0)] = tmp12;
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1 = args
args.clear()
buf1 = empty_strided((3, 10, 10), (100, 10, 1), device='cpu', dtype=torch.float32)
kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(arg1_1.data_ptr()), c_void_p(buf1.data_ptr()))
del arg0_1
del arg1_1
return (buf1, )
if __name__ == "__main__":
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((10, 3, 10), (30, 10, 1), device='cpu', dtype=torch.float32)
arg1_1 = rand_strided((3, 10, 10), (100, 10, 1), device='cpu', dtype=torch.float32)
print_performance(lambda: call([arg0_1, arg1_1]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93336
Approved by: https://github.com/jansel