inductor: don't do transpose vectorization if input ld depends on the innermost var (#94493)
Fixes https://github.com/pytorch/pytorch/issues/94269.
For the following case, the CPP backend generated a transposed vectorized load whose leading dimension (ld) depends on the innermost loop variable, which produced incorrect results:
```
import torch
import torchvision
#import intel_extension_for_pytorch
import torch._dynamo
from torch._inductor import config

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        constant_pad_nd = x
        # File: /home/xiaobing/miniconda3/envs/pytorch_te_binary/lib/python3.8/site-packages/timm/models/layers/halo_attn.py:195, code: kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size)
        as_strided: f32[1, 384, 2, 20, 12] = torch.ops.aten.as_strided.default(constant_pad_nd, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680]); constant_pad_nd = None
        as_strided_1: f32[1, 384, 2, 2, 12, 12] = torch.ops.aten.as_strided.default(as_strided, [1, 384, 2, 2, 12, 12], [153600, 1, 61440, 3072, 7680, 384]); as_strided = None
        # File: /home/xiaobing/miniconda3/envs/pytorch_te_binary/lib/python3.8/site-packages/timm/models/layers/halo_attn.py:197, code: kv = kv.reshape(
        clone_1: f32[1, 384, 2, 2, 12, 12] = torch.ops.aten.clone.default(as_strided_1, memory_format = torch.contiguous_format); as_strided_1 = None
        _unsafe_view_1: f32[8, 48, 4, 144] = torch.ops.aten._unsafe_view.default(clone_1, [8, 48, 4, 144]); clone_1 = None
        permute_2: f32[8, 4, 144, 48] = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1]); _unsafe_view_1 = None
        # File: /home/xiaobing/miniconda3/envs/pytorch_te_binary/lib/python3.8/site-packages/timm/models/layers/halo_attn.py:202, code: k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(permute_2, [16, 32], -1); permute_2 = None
        getitem: f32[8, 4, 144, 16] = split_with_sizes[0]
        getitem_1: f32[8, 4, 144, 32] = split_with_sizes[1]; split_with_sizes = None
        permute_3: f32[8, 4, 16, 144] = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2]); getitem = None
        expand_1: f32[8, 4, 16, 144] = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144]); permute_3 = None
        clone_3: f32[8, 4, 16, 144] = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None
        return clone_3

model = Model().eval()
opt_model = torch._dynamo.optimize('inductor')(model)

x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last)

ref = model(x)
with torch.no_grad():
    for i in range(3):
        out = opt_model(x)

print(torch.equal(ref, out))
```
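The `unfold`/`as_strided` views are what make this load irregular: flattening the last two 12-sized dims of `as_strided_1` into one 144-wide output dimension, the source offset for index `i3` is `7680*(i3 / 12) + 384*(i3 % 12)` (the same expression appears in the generated kernel below), so the stride between neighboring elements along the innermost dimension is not constant. A quick standalone check, using only the offsets implied by those strides:
```
# Source offsets along the flattened innermost output dim (0 <= i3 < 144),
# derived from the as_strided_1 strides above: 7680*(i3 // 12) + 384*(i3 % 12).
offsets = [7680 * (i3 // 12) + 384 * (i3 % 12) for i3 in range(144)]
strides = sorted({b - a for a, b in zip(offsets, offsets[1:])})
print(strides)  # [384, 3456]: the stride jumps at every multiple of 12,
                # so no single leading dimension covers a 16-wide tile
```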
The generated code before this PR is:
```
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/ni/cniims6nap7c5wars7cmtbjr3mw6b5cxyoyxmsu7ro2l5fkrwatl.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       float* __restrict__ out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long i0=0; i0<8; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<4; i1+=1)
            {
                #pragma GCC ivdep
                for(long i2=0; i2<1; i2+=1)
                {
                    #pragma GCC ivdep
                    for(long i3=0; i3<9; i3+=1)
                    {
                        float tmp0[16*16] __attribute__ ((aligned (16)));
                        at::vec::transpose_mxn<float,16,16>(in_ptr0 + (16*i2) + (48*i0) + (384*((16*i3) % 12)) + (3072*(i1 % 2)) + (7680*(((4*i3) / 3))) + (61440*(i1 / 2)), ((-7680)*(i3 / 12)) + ((-384)*(i3 % 12)) + (384*((1 + i3) % 12)) + (7680*(((1 + i3) / 12))), tmp0, 16);
                        for (long i2_inner = 0; i2_inner < 16; i2_inner++)
                        {
                            auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + 16*i2_inner);
                            tmp1.store(out_ptr0 + (16*i3) + (144*i2_inner) + (2304*i1) + (2304*i2) + (9216*i0));
                        }
                    }
                    #pragma GCC ivdep
                    for(long i3=144; i3<144; i3+=1)
                    {
                        for (long i2_inner = 0; i2_inner < 16; i2_inner++)
                        {
                            auto tmp0 = in_ptr0[i2_inner + (16*i2) + (48*i0) + (384*(i3 % 12)) + (3072*(i1 % 2)) + (7680*(i3 / 12)) + (61440*(i1 / 2))];
                            out_ptr0[i3 + (144*i2_inner) + (2304*i1) + (2304*i2) + (9216*i0)] = tmp0;
                        }
                    }
                }
                #pragma GCC ivdep
                for(long i2=16; i2<16; i2+=1)
                {
                    #pragma GCC ivdep
                    for(long i3=0; i3<144; i3+=1)
                    {
                        auto tmp0 = in_ptr0[i2 + (48*i0) + (384*(i3 % 12)) + (3072*(i1 % 2)) + (7680*(i3 / 12)) + (61440*(i1 / 2))];
                        out_ptr0[i3 + (144*i2) + (2304*i1) + (9216*i0)] = tmp0;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, = args
    args.clear()
    buf0 = empty_strided((8, 4, 16, 144), (9216, 2304, 144, 1), device='cpu', dtype=torch.float32)
    kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
    del arg0_1
    return (buf0, )
```
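Note the second argument to `at::vec::transpose_mxn`: the leading dimension of the source tile is itself an expression in the innermost loop variable `i3`. `transpose_mxn` assumes all rows of the 16x16 tile are a fixed `ld` apart, so the stride computed at the tile's first row is wrong for the remaining rows whenever the tile crosses a multiple of 12, and with a 16-wide tile over a period-12 pattern every tile does. Conceptually, the fix boils down to: compute the load's stride with respect to the innermost variable, and refuse transpose vectorization if that stride still contains the variable. A minimal sympy sketch of that dependence test (illustrative only, not the actual check in inductor's CPP backend):
```
import sympy

def ld_stride(index_expr, var):
    """Stride of an index expression when `var` advances by one."""
    return sympy.simplify(index_expr.subs(var, var + 1) - index_expr)

i3 = sympy.Symbol("i3", integer=True, nonnegative=True)
# Load index from the kernel above (other loop vars held fixed);
# C integer division `i3 / 12` becomes floor(i3/12) in sympy.
idx = 384 * sympy.Mod(i3, 12) + 7680 * sympy.floor(i3 / 12)
ld = ld_stride(idx, i3)
print(i3 in ld.free_symbols)  # True: the ld depends on the innermost
                              # variable, so transpose vectorization is unsafe
```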
After this PR, inductor skips the transpose vectorization for this load and falls back to scalar code, which produces the correct result:
```
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_xiaobing/dm/cdmaihqxwe73zkb3he2zizktpq5uujetg2db26c3r4lgsmlx3b4c.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       float* __restrict__ out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long i0=0; i0<8; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<4; i1+=1)
            {
                #pragma GCC ivdep
                for(long i2=0; i2<16; i2+=1)
                {
                    #pragma GCC ivdep
                    for(long i3=0; i3<144; i3+=1)
                    {
                        auto tmp0 = in_ptr0[i2 + (48*i0) + (384*(i3 % 12)) + (3072*(i1 % 2)) + (7680*(i3 / 12)) + (61440*(i1 / 2))];
                        out_ptr0[i3 + (144*i2) + (2304*i1) + (9216*i0)] = tmp0;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, = args
    args.clear()
    buf0 = empty_strided((8, 4, 16, 144), (9216, 2304, 144, 1), device='cpu', dtype=torch.float32)
    kernel_cpp_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
    del arg0_1
    return (buf0, )

if __name__ == "__main__":
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1, 384, 20, 20), (153600, 1, 7680, 384), device='cpu', dtype=torch.float32)
    print_performance(lambda: call([arg0_1]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94493
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/EikanWang