inductor: don't remember pre-loop order if pre loop has loop collapse (#96640)
Given the following case from timm **ese_vovnet19b_dw**:
```
import torch
import torch._dynamo
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = torch.nn.Conv2d(256, 256, kernel_size=1, padding=0)
        self.conv2 = torch.nn.Conv2d(256, 256, kernel_size=1, padding=0)
        self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

    def forward(self, x):
        x = self.conv1(x)
        x2 = self.conv2(x)
        y = x2 * x
        return self.pool(y)

model = Model().to(memory_format=torch.channels_last).eval()
x = torch.randn(128, 256, 56, 56).to(memory_format=torch.channels_last)
opt_model = torch._dynamo.optimize('inductor')(model)
with torch.no_grad():
    for i in range(2):
        y = opt_model(x)
```
before this PR, the max_pooling can't be vectorized:
```
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr0)
{
#pragma omp parallel num_threads(40)
{
{
#pragma omp for
for(long i0=0; i0<6422528; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 16*i0);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 16*i0);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + 16*i0);
}
#pragma omp for simd simdlen(8)
for(long i0=102760448; i0<102760448; i0+=1)
{
auto tmp0 = in_ptr0[i0];
auto tmp1 = in_out_ptr0[i0];
auto tmp2 = tmp0 * tmp1;
in_out_ptr0[i0] = tmp2;
}
}
{
#pragma omp for
for(long i0=0; i0<128; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<256; i1+=1)
{
#pragma GCC ivdep
for(long i2=0; i2<28; i2+=1)
{
#pragma GCC ivdep
for(long i3=0; i3<28; i3+=1)
{
auto tmp0 = static_cast<long>(2*i2);
auto tmp1 = static_cast<long>(0);
auto tmp2 = tmp0 >= tmp1;
auto tmp3 = static_cast<long>(56);
auto tmp4 = tmp0 < tmp3;
auto tmp5 = tmp2 & tmp4;
auto tmp6 = static_cast<long>(2*i3);
auto tmp7 = tmp6 >= tmp1;
auto tmp8 = tmp6 < tmp3;
auto tmp9 = tmp7 & tmp8;
auto tmp10 = tmp5 & tmp9;
auto tmp11 = [&]
{
auto tmp12 = in_out_ptr0[i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp12;
}
;
auto tmp13 = tmp10 ? tmp11() : -std::numeric_limits<decltype(tmp11())>::infinity();
auto tmp14 = static_cast<long>(1 + (2*i3));
auto tmp15 = tmp14 >= tmp1;
auto tmp16 = tmp14 < tmp3;
auto tmp17 = tmp15 & tmp16;
auto tmp18 = tmp5 & tmp17;
auto tmp19 = [&]
{
auto tmp20 = in_out_ptr0[256 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp20;
}
;
auto tmp21 = tmp18 ? tmp19() : -std::numeric_limits<decltype(tmp19())>::infinity();
auto tmp22 = (tmp13 != tmp13) ? tmp13 : std::max(tmp21, tmp13);
auto tmp23 = static_cast<long>(2 + (2*i3));
auto tmp24 = tmp23 >= tmp1;
auto tmp25 = tmp23 < tmp3;
auto tmp26 = tmp24 & tmp25;
auto tmp27 = tmp5 & tmp26;
auto tmp28 = [&]
{
auto tmp29 = in_out_ptr0[512 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp29;
}
;
auto tmp30 = tmp27 ? tmp28() : -std::numeric_limits<decltype(tmp28())>::infinity();
auto tmp31 = (tmp22 != tmp22) ? tmp22 : std::max(tmp30, tmp22);
auto tmp32 = static_cast<long>(1 + (2*i2));
auto tmp33 = tmp32 >= tmp1;
auto tmp34 = tmp32 < tmp3;
auto tmp35 = tmp33 & tmp34;
auto tmp36 = tmp35 & tmp9;
auto tmp37 = [&]
{
auto tmp38 = in_out_ptr0[14336 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp38;
}
;
auto tmp39 = tmp36 ? tmp37() : -std::numeric_limits<decltype(tmp37())>::infinity();
auto tmp40 = (tmp31 != tmp31) ? tmp31 : std::max(tmp39, tmp31);
auto tmp41 = tmp35 & tmp17;
auto tmp42 = [&]
{
auto tmp43 = in_out_ptr0[14592 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp43;
}
;
auto tmp44 = tmp41 ? tmp42() : -std::numeric_limits<decltype(tmp42())>::infinity();
auto tmp45 = (tmp40 != tmp40) ? tmp40 : std::max(tmp44, tmp40);
auto tmp46 = tmp35 & tmp26;
auto tmp47 = [&]
{
auto tmp48 = in_out_ptr0[14848 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp48;
}
;
auto tmp49 = tmp46 ? tmp47() : -std::numeric_limits<decltype(tmp47())>::infinity();
auto tmp50 = (tmp45 != tmp45) ? tmp45 : std::max(tmp49, tmp45);
auto tmp51 = static_cast<long>(2 + (2*i2));
auto tmp52 = tmp51 >= tmp1;
auto tmp53 = tmp51 < tmp3;
auto tmp54 = tmp52 & tmp53;
auto tmp55 = tmp54 & tmp9;
auto tmp56 = [&]
{
auto tmp57 = in_out_ptr0[28672 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp57;
}
;
auto tmp58 = tmp55 ? tmp56() : -std::numeric_limits<decltype(tmp56())>::infinity();
auto tmp59 = (tmp50 != tmp50) ? tmp50 : std::max(tmp58, tmp50);
auto tmp60 = tmp54 & tmp17;
auto tmp61 = [&]
{
auto tmp62 = in_out_ptr0[28928 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp62;
}
;
auto tmp63 = tmp60 ? tmp61() : -std::numeric_limits<decltype(tmp61())>::infinity();
auto tmp64 = (tmp59 != tmp59) ? tmp59 : std::max(tmp63, tmp59);
auto tmp65 = tmp54 & tmp26;
auto tmp66 = [&]
{
auto tmp67 = in_out_ptr0[29184 + i1 + (512*i3) + (28672*i2) + (802816*i0)];
return tmp67;
}
;
auto tmp68 = tmp65 ? tmp66() : -std::numeric_limits<decltype(tmp66())>::infinity();
auto tmp69 = (tmp64 != tmp64) ? tmp64 : std::max(tmp68, tmp64);
out_ptr0[i1 + (256*i3) + (7168*i2) + (200704*i0)] = tmp69;
}
}
}
}
}
}
}
```
We always avoid reordering the loops when the pre-loop has a loop collapse: https://github.com/pytorch/pytorch/blob/2cbce06feebf5f52ef5539a3b1ae2e003217b6ac/torch/_inductor/ir.py#L2267-L2273.
This PR adds a check so that the pre-loop ordering is only reused when the pre-loop does not have a loop collapse.
After this PR, the generated code is
```
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr0)
{
#pragma omp parallel num_threads(40)
{
{
#pragma omp for
for(long i0=0; i0<6422528; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 16*i0);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 16*i0);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + 16*i0);
}
#pragma omp for simd simdlen(8)
for(long i0=102760448; i0<102760448; i0+=1)
{
auto tmp0 = in_ptr0[i0];
auto tmp1 = in_out_ptr0[i0];
auto tmp2 = tmp0 * tmp1;
in_out_ptr0[i0] = tmp2;
}
}
{
#pragma omp for
for(long i0=0; i0<128; i0+=1)
{
#pragma GCC ivdep
for(long i1=0; i1<28; i1+=1)
{
#pragma GCC ivdep
for(long i2=0; i2<28; i2+=1)
{
for(long i3=0; i3<16; i3+=1)
{
auto tmp0 = at::vec::Vectorized<int>(static_cast<int>(2*i1));
auto tmp1 = at::vec::Vectorized<int>(static_cast<int>(0));
auto tmp2 = tmp0 >= tmp1;
auto tmp3 = at::vec::Vectorized<int>(static_cast<int>(56));
auto tmp4 = tmp0 < tmp3;
auto tmp5 = tmp2 & tmp4;
auto tmp6 = at::vec::Vectorized<int>(static_cast<int>(2*i2));
auto tmp7 = tmp6 >= tmp1;
auto tmp8 = tmp6 < tmp3;
auto tmp9 = tmp7 & tmp8;
auto tmp10 = tmp5 & tmp9;
auto tmp11 = [&]
{
auto tmp12 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp12;
}
;
auto tmp13 = decltype(tmp11())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp11(), to_float_mask(tmp10) != at::vec::Vectorized<float>(0));
auto tmp14 = at::vec::Vectorized<int>(static_cast<int>(1 + (2*i2)));
auto tmp15 = tmp14 >= tmp1;
auto tmp16 = tmp14 < tmp3;
auto tmp17 = tmp15 & tmp16;
auto tmp18 = tmp5 & tmp17;
auto tmp19 = [&]
{
auto tmp20 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 256 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp20;
}
;
auto tmp21 = decltype(tmp19())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp19(), to_float_mask(tmp18) != at::vec::Vectorized<float>(0));
auto tmp22 = at::vec::maximum(tmp21, tmp13);
auto tmp23 = at::vec::Vectorized<int>(static_cast<int>(2 + (2*i2)));
auto tmp24 = tmp23 >= tmp1;
auto tmp25 = tmp23 < tmp3;
auto tmp26 = tmp24 & tmp25;
auto tmp27 = tmp5 & tmp26;
auto tmp28 = [&]
{
auto tmp29 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 512 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp29;
}
;
auto tmp30 = decltype(tmp28())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp28(), to_float_mask(tmp27) != at::vec::Vectorized<float>(0));
auto tmp31 = at::vec::maximum(tmp30, tmp22);
auto tmp32 = at::vec::Vectorized<int>(static_cast<int>(1 + (2*i1)));
auto tmp33 = tmp32 >= tmp1;
auto tmp34 = tmp32 < tmp3;
auto tmp35 = tmp33 & tmp34;
auto tmp36 = tmp35 & tmp9;
auto tmp37 = [&]
{
auto tmp38 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 14336 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp38;
}
;
auto tmp39 = decltype(tmp37())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp37(), to_float_mask(tmp36) != at::vec::Vectorized<float>(0));
auto tmp40 = at::vec::maximum(tmp39, tmp31);
auto tmp41 = tmp35 & tmp17;
auto tmp42 = [&]
{
auto tmp43 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 14592 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp43;
}
;
auto tmp44 = decltype(tmp42())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp42(), to_float_mask(tmp41) != at::vec::Vectorized<float>(0));
auto tmp45 = at::vec::maximum(tmp44, tmp40);
auto tmp46 = tmp35 & tmp26;
auto tmp47 = [&]
{
auto tmp48 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 14848 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp48;
}
;
auto tmp49 = decltype(tmp47())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp47(), to_float_mask(tmp46) != at::vec::Vectorized<float>(0));
auto tmp50 = at::vec::maximum(tmp49, tmp45);
auto tmp51 = at::vec::Vectorized<int>(static_cast<int>(2 + (2*i1)));
auto tmp52 = tmp51 >= tmp1;
auto tmp53 = tmp51 < tmp3;
auto tmp54 = tmp52 & tmp53;
auto tmp55 = tmp54 & tmp9;
auto tmp56 = [&]
{
auto tmp57 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 28672 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp57;
}
;
auto tmp58 = decltype(tmp56())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp56(), to_float_mask(tmp55) != at::vec::Vectorized<float>(0));
auto tmp59 = at::vec::maximum(tmp58, tmp50);
auto tmp60 = tmp54 & tmp17;
auto tmp61 = [&]
{
auto tmp62 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 28928 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp62;
}
;
auto tmp63 = decltype(tmp61())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp61(), to_float_mask(tmp60) != at::vec::Vectorized<float>(0));
auto tmp64 = at::vec::maximum(tmp63, tmp59);
auto tmp65 = tmp54 & tmp26;
auto tmp66 = [&]
{
auto tmp67 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + 29184 + (16*i3) + (512*i2) + (28672*i1) + (802816*i0));
return tmp67;
}
;
auto tmp68 = decltype(tmp66())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp66(), to_float_mask(tmp65) != at::vec::Vectorized<float>(0));
auto tmp69 = at::vec::maximum(tmp68, tmp64);
tmp69.store(out_ptr0 + (16*i3) + (256*i2) + (7168*i1) + (200704*i0));
}
#pragma omp simd simdlen(8)
for(long i3=256; i3<256; i3+=1)
{
auto tmp0 = static_cast<long>(2*i1);
auto tmp1 = static_cast<long>(0);
auto tmp2 = tmp0 >= tmp1;
auto tmp3 = static_cast<long>(56);
auto tmp4 = tmp0 < tmp3;
auto tmp5 = tmp2 & tmp4;
auto tmp6 = static_cast<long>(2*i2);
auto tmp7 = tmp6 >= tmp1;
auto tmp8 = tmp6 < tmp3;
auto tmp9 = tmp7 & tmp8;
auto tmp10 = tmp5 & tmp9;
auto tmp11 = [&]
{
auto tmp12 = in_out_ptr0[i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp12;
}
;
auto tmp13 = tmp10 ? tmp11() : -std::numeric_limits<decltype(tmp11())>::infinity();
auto tmp14 = static_cast<long>(1 + (2*i2));
auto tmp15 = tmp14 >= tmp1;
auto tmp16 = tmp14 < tmp3;
auto tmp17 = tmp15 & tmp16;
auto tmp18 = tmp5 & tmp17;
auto tmp19 = [&]
{
auto tmp20 = in_out_ptr0[256 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp20;
}
;
auto tmp21 = tmp18 ? tmp19() : -std::numeric_limits<decltype(tmp19())>::infinity();
auto tmp22 = (tmp13 != tmp13) ? tmp13 : std::max(tmp21, tmp13);
auto tmp23 = static_cast<long>(2 + (2*i2));
auto tmp24 = tmp23 >= tmp1;
auto tmp25 = tmp23 < tmp3;
auto tmp26 = tmp24 & tmp25;
auto tmp27 = tmp5 & tmp26;
auto tmp28 = [&]
{
auto tmp29 = in_out_ptr0[512 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp29;
}
;
auto tmp30 = tmp27 ? tmp28() : -std::numeric_limits<decltype(tmp28())>::infinity();
auto tmp31 = (tmp22 != tmp22) ? tmp22 : std::max(tmp30, tmp22);
auto tmp32 = static_cast<long>(1 + (2*i1));
auto tmp33 = tmp32 >= tmp1;
auto tmp34 = tmp32 < tmp3;
auto tmp35 = tmp33 & tmp34;
auto tmp36 = tmp35 & tmp9;
auto tmp37 = [&]
{
auto tmp38 = in_out_ptr0[14336 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp38;
}
;
auto tmp39 = tmp36 ? tmp37() : -std::numeric_limits<decltype(tmp37())>::infinity();
auto tmp40 = (tmp31 != tmp31) ? tmp31 : std::max(tmp39, tmp31);
auto tmp41 = tmp35 & tmp17;
auto tmp42 = [&]
{
auto tmp43 = in_out_ptr0[14592 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp43;
}
;
auto tmp44 = tmp41 ? tmp42() : -std::numeric_limits<decltype(tmp42())>::infinity();
auto tmp45 = (tmp40 != tmp40) ? tmp40 : std::max(tmp44, tmp40);
auto tmp46 = tmp35 & tmp26;
auto tmp47 = [&]
{
auto tmp48 = in_out_ptr0[14848 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp48;
}
;
auto tmp49 = tmp46 ? tmp47() : -std::numeric_limits<decltype(tmp47())>::infinity();
auto tmp50 = (tmp45 != tmp45) ? tmp45 : std::max(tmp49, tmp45);
auto tmp51 = static_cast<long>(2 + (2*i1));
auto tmp52 = tmp51 >= tmp1;
auto tmp53 = tmp51 < tmp3;
auto tmp54 = tmp52 & tmp53;
auto tmp55 = tmp54 & tmp9;
auto tmp56 = [&]
{
auto tmp57 = in_out_ptr0[28672 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp57;
}
;
auto tmp58 = tmp55 ? tmp56() : -std::numeric_limits<decltype(tmp56())>::infinity();
auto tmp59 = (tmp50 != tmp50) ? tmp50 : std::max(tmp58, tmp50);
auto tmp60 = tmp54 & tmp17;
auto tmp61 = [&]
{
auto tmp62 = in_out_ptr0[28928 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp62;
}
;
auto tmp63 = tmp60 ? tmp61() : -std::numeric_limits<decltype(tmp61())>::infinity();
auto tmp64 = (tmp59 != tmp59) ? tmp59 : std::max(tmp63, tmp59);
auto tmp65 = tmp54 & tmp26;
auto tmp66 = [&]
{
auto tmp67 = in_out_ptr0[29184 + i3 + (512*i2) + (28672*i1) + (802816*i0)];
return tmp67;
}
;
auto tmp68 = tmp65 ? tmp66() : -std::numeric_limits<decltype(tmp66())>::infinity();
auto tmp69 = (tmp64 != tmp64) ? tmp64 : std::max(tmp68, tmp64);
out_ptr0[i3 + (256*i2) + (7168*i1) + (200704*i0)] = tmp69;
}
}
}
}
}
}
```
After this PR, we get an **18%** performance improvement for timm **ese_vovnet19b_dw** on skx-4148 (measured with ```python -m torch.backends.xeon.run_cpu --node_id 0 benchmarks/dynamo/timm_models.py --performance --float32 -dcpu -n50 --inductor --channels-last --no-skip --dashboard --only ese_vovnet19b_dw```).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96640
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel