1204463b - inductor: fix bfloat16 reduction crash caused by storing a float value to bfloat16 (#102719)

inductor: fix bfloat16 reduction crash caused by storing a float value to bfloat16 (#102719)

For bfloat16 reductions, the generated code contained an incorrect store that wrote the float accumulator directly into the bfloat16 output buffer, causing a crash.

Before:

```
extern "C" void kernel(const bfloat16* in_ptr0,
                       bfloat16* out_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(40)
    {
        {
            #pragma omp for
            for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
            {
                {
                    #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={{-std::numeric_limits<float>::infinity()}})
                    float tmp_acc0 = -std::numeric_limits<float>::infinity();
                    auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
                    for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
                    {
                        auto tmp0 = load_bf16_as_float(in_ptr0 + static_cast<long>(i0 + (16L*i1)));
                        auto tmp1 = (tmp0);
                        tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp1);
                    }
                    // Bug: stores the float vector directly into the bfloat16 buffer.
                    tmp_acc0_vec.store(out_ptr0 + static_cast<long>(i0));
                }
            }
        }
        #pragma omp single
        {
            {
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
                {
                    auto tmp0 = load_bf16_as_float(out_ptr0 + static_cast<long>(i0));
                    auto tmp1 = (tmp0);
                    tmp1.store(out_ptr1 + static_cast<long>(i0));
                }
            }
        }
    }
}
```

After:

```
extern "C" void kernel(const bfloat16* in_ptr0,
                       bfloat16* out_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(40)
    {
        {
            #pragma omp for
            for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
            {
                {
                    #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={{-std::numeric_limits<float>::infinity()}})
                    float tmp_acc0 = -std::numeric_limits<float>::infinity();
                    auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
                    for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
                    {
                        auto tmp0 = load_bf16_as_float(in_ptr0 + static_cast<long>(i0 + (16L*i1)));
                        auto tmp1 = (tmp0);
                        tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp1);
                    }
                    // Fix: converts the float vector to bfloat16 before storing.
                    store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp_acc0_vec);
                }
            }
        }
        #pragma omp single
        {
            {
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
                {
                    auto tmp0 = load_bf16_as_float(out_ptr0 + static_cast<long>(i0));
                    auto tmp1 = (tmp0);
                    tmp1.store(out_ptr1 + static_cast<long>(i0));
                }
            }
        }
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102719
Approved by: https://github.com/jansel, https://github.com/jgong5
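The crash comes from a width mismatch: `at::vec::Vectorized<float>::store` writes 4 bytes per lane, so storing the 16-lane float accumulator through a `bfloat16*` writes 64 bytes into a 32-byte (16 x 2-byte) output region and overruns the buffer. Below is a minimal scalar sketch of what a float-to-bfloat16 store has to do; it is an illustration only, not the actual vectorized `store_float_as_bf16` helper used by inductor, and NaN handling is omitted.

```
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical scalar stand-in for the vectorized float -> bfloat16 store.
// Each 32-bit float is rounded to the nearest bfloat16 (round-to-nearest-even,
// NaN handling omitted) and written as 2 bytes, so an n-element store touches
// n*2 bytes instead of the n*4 bytes a raw float store would write.
static void store_float_as_bf16_scalar(uint16_t* dst, const float* src, int n) {
    for (int i = 0; i < n; ++i) {
        uint32_t bits;
        std::memcpy(&bits, &src[i], sizeof(bits));               // bit-cast float to u32
        uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);  // round-to-nearest-even bias
        dst[i] = static_cast<uint16_t>((bits + rounding_bias) >> 16); // keep the high 16 bits
    }
}

int main() {
    float acc[16];
    for (int i = 0; i < 16; ++i) acc[i] = 1.0f + 0.001f * i;  // stand-in reduction results
    uint16_t out[16];                                          // 32-byte bfloat16 buffer
    store_float_as_bf16_scalar(out, acc, 16);
    std::printf("bf16 bits of acc[0] = 0x%04x\n", out[0]);     // 1.0f -> 0x3f80
    return 0;
}
```

With a conversion like this in place, the kernel writes exactly as many bytes as the bfloat16 buffer holds, which is why swapping `tmp_acc0_vec.store(...)` for `store_float_as_bf16(...)` resolves the crash.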