inductor(cpu): fix C++ compile error when sigmoid's post ops is a reduction op (#94890) (#95054)
For the timm **nfnet_l0** model, the CPU path fails with the following error: `torch._dynamo.exc.BackendCompilerFailed: inductor raised CppCompileError: C++ compile error`.
Here is a simple test case that reproduces the issue:
```
import torch
import torch._dynamo

def fn(x):
    x = torch.ops.aten.sigmoid.default(x)
    return torch.ops.aten.mean.dim(x, [-1, -2], True)

x = torch.randn((1, 8, 8, 8))
opt_fn = torch._dynamo.optimize("inductor")(fn)
opt_fn(x)  # first call triggers compilation; this is where the C++ compile error is raised
real_out = fn(x)
compiled_out = opt_fn(x)
tol = 0.0001
print(torch.allclose(real_out, compiled_out, atol=tol, rtol=tol))
```
Before (the scalar tail loop of the reduction accumulates into `tmp3`, which is never declared; the declared accumulator is `tmp2`, so the generated kernel fails to compile):
```
extern "C" void kernel(float* __restrict__ in_out_ptr0,
const float* __restrict__ in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
{
#pragma GCC ivdep
for(long i0=0; i0<8; i0+=1)
{
{
#pragma omp declare reduction(+:at::vec::Vectorized<float>:omp_out += omp_in) initializer(omp_priv={{0}})
float tmp2 = 0;
auto tmp2_vec = at::vec::Vectorized<float>(tmp2);
for(long i1=0; i1<4; i1+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + (16*i1) + (64*i0));
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
tmp2_vec += tmp1;
}
#pragma omp simd simdlen(8) reduction(+:tmp3)
for(long i1=64; i1<64; i1+=1)
{
auto tmp0 = in_ptr0[i1 + (64*i0)];
auto tmp1 = std::exp(-tmp0);
auto tmp2 = 1 / (1 + tmp1);
tmp3 += tmp2;
}
tmp2 += at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return x + y;}, tmp2_vec);
out_ptr0[i0] = tmp3;
}
}
}
{
for(long i0=0; i0<0; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + 16*i0);
auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(64));
auto tmp2 = tmp0 / tmp1;
tmp2.store(in_out_ptr0 + 16*i0);
}
#pragma omp simd simdlen(8)
for(long i0=0; i0<8; i0+=1)
{
auto tmp0 = out_ptr0[i0];
auto tmp1 = static_cast<float>(64);
auto tmp2 = tmp0 / tmp1;
in_out_ptr0[i0] = tmp2;
}
}
}
```
After (the tail loop and the final store now use the declared accumulator `tmp2`, so the kernel compiles):
```
extern "C" void kernel(float* __restrict__ in_out_ptr0,
const float* __restrict__ in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
#pragma omp parallel num_threads(40)
{
{
#pragma omp for
for(long i0=0; i0<8; i0+=1)
{
{
#pragma omp declare reduction(+:at::vec::Vectorized<float>:omp_out += omp_in) initializer(omp_priv={{0}})
float tmp2 = 0;
auto tmp2_vec = at::vec::Vectorized<float>(tmp2);
for(long i1=0; i1<4; i1+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + (16*i1) + (64*i0));
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
tmp2_vec += tmp1;
}
#pragma omp simd simdlen(8) reduction(+:tmp2)
for(long i1=64; i1<64; i1+=1)
{
auto tmp0 = in_ptr0[i1 + (64*i0)];
auto tmp1 = decltype(tmp0)(1) / (decltype(tmp0)(1) + std::exp(-tmp0));
tmp2 += tmp1;
}
tmp2 += at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return x + y;}, tmp2_vec);
out_ptr0[i0] = tmp2;
}
}
}
#pragma omp single
{
{
for(long i0=0; i0<0; i0+=1)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + 16*i0);
auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(64));
auto tmp2 = tmp0 / tmp1;
tmp2.store(in_out_ptr0 + 16*i0);
}
#pragma omp simd simdlen(8)
for(long i0=0; i0<8; i0+=1)
{
auto tmp0 = out_ptr0[i0];
auto tmp1 = static_cast<float>(64);
auto tmp2 = tmp0 / tmp1;
in_out_ptr0[i0] = tmp2;
}
}
}
}
}
```
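For reference, below is a minimal standalone sketch (not Inductor output; plain scalar arrays stand in for `at::vec::Vectorized<float>`, and the `tmp2`/`tmp2_vec` names simply mirror the generated kernel) of the reduction pattern the fixed codegen follows: the scalar tail loop accumulates into the same declared accumulator `tmp2` that the horizontal reduction of the vector lanes is later merged into, instead of referencing an undeclared `tmp3`.
```
// Hypothetical, simplified sketch of the fixed vectorized-reduction-with-tail pattern.
#include <cstdio>

int main() {
    constexpr int N = 70;           // 4 full 16-lane vector iterations plus a 6-element tail
    float in[N];
    for (int i = 0; i < N; ++i) in[i] = 1.0f;

    float tmp2 = 0.0f;              // scalar accumulator, shared with the tail loop
    float tmp2_vec[16] = {0.0f};    // stand-in for the at::vec::Vectorized<float> accumulator

    // Vectorized main loop: 16 lanes per iteration.
    for (int i = 0; i + 16 <= N; i += 16) {
        for (int lane = 0; lane < 16; ++lane) {
            tmp2_vec[lane] += in[i + lane];
        }
    }

    // Scalar tail loop: reduces into the *declared* accumulator tmp2, which is
    // what the "after" kernel does (the "before" kernel reduced into an
    // undeclared tmp3 and therefore failed to compile).
    #pragma omp simd reduction(+ : tmp2)
    for (int i = (N / 16) * 16; i < N; ++i) {
        tmp2 += in[i];
    }

    // Horizontal reduction of the vector lanes, merged into the same tmp2
    // (the generated kernel uses at::vec::vec_reduce_all for this step).
    for (int lane = 0; lane < 16; ++lane) {
        tmp2 += tmp2_vec[lane];
    }

    printf("sum = %f\n", tmp2);     // expected: 70
    return 0;
}
```
Compiled with, e.g., `g++ -O2 -fopenmp-simd`, this prints `sum = 70.000000`; swapping the tail-loop accumulator for an undeclared name reproduces the kind of undeclared-identifier error that surfaced as the `CppCompileError` above.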
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94890
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/lezcano