inductor: fix bfloat16 reduction crash issue which store float value to bfloat16 (#102719)
For bfloat16 reduction, there was an incorrect store which wrote the float accumulator value directly to a bfloat16 output without conversion:
Before:
```
extern "C" void kernel(const bfloat16* in_ptr0,
bfloat16* out_ptr0,
float* out_ptr1)
{
#pragma omp parallel num_threads(40)
{
{
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
{
{
#pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={{-std::numeric_limits<float>::infinity()}})
float tmp_acc0 = -std::numeric_limits<float>::infinity();
auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
{
auto tmp0 = load_bf16_as_float(in_ptr0 + static_cast<long>(i0 + (16L*i1)));
auto tmp1 = (tmp0);
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp1);
}
tmp_acc0_vec.store(out_ptr0 + static_cast<long>(i0));
}
}
}
#pragma omp single
{
{
for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
{
auto tmp0 = load_bf16_as_float(out_ptr0 + static_cast<long>(i0));
auto tmp1 = (tmp0);
tmp1.store(out_ptr1 + static_cast<long>(i0));
}
}
}
}
}
```
After:
```
extern "C" void kernel(const bfloat16* in_ptr0,
bfloat16* out_ptr0,
float* out_ptr1)
{
#pragma omp parallel num_threads(40)
{
{
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
{
{
#pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={{-std::numeric_limits<float>::infinity()}})
float tmp_acc0 = -std::numeric_limits<float>::infinity();
auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
{
auto tmp0 = load_bf16_as_float(in_ptr0 + static_cast<long>(i0 + (16L*i1)));
auto tmp1 = (tmp0);
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp1);
}
store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp_acc0_vec);
}
}
}
#pragma omp single
{
{
for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
{
auto tmp0 = load_bf16_as_float(out_ptr0 + static_cast<long>(i0));
auto tmp1 = (tmp0);
tmp1.store(out_ptr1 + static_cast<long>(i0));
}
}
}
}
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102719
Approved by: https://github.com/jansel, https://github.com/jgong5