pytorch
91279f94 - [inductor][quant]Enable inductor vec codegen for quantization (#98489)

Commit
1 year ago
[inductor][quant]Enable inductor vec codegen for quantization (#98489) **Summary** Enable the `decomposed dequant - pointwise ops - decomposed quant` vectorization code gen inside inductor. Here is the example in the UT and the generated code: Example: * `decomposed dequant - relu - decomposed quant` pattern. * Using `uint8` as the quantized tensor data type. Generated Code: ``` kernel_cpp_0 = async_compile.cpp(''' #include "/tmp/torchinductor_root/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h" extern "C" void kernel(const unsigned char* in_ptr0, unsigned char* out_ptr0) { #pragma omp parallel num_threads(56) { { #pragma omp for for(long i0=static_cast<long>(0); i0<static_cast<long>(27); i0+=static_cast<long>(1)) { auto tmp0 = at::vec::load_uint8_as_float(in_ptr0 + static_cast<long>(16*i0)); auto tmp1 = (tmp0); auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(100)); auto tmp3 = tmp1 - tmp2; auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(0.01)); auto tmp5 = tmp3 * tmp4; auto tmp6 = at::vec::clamp_min(tmp5, decltype(tmp5)(0)); auto tmp7 = at::vec::Vectorized<float>(static_cast<float>(100.0)); auto tmp8 = tmp6 * tmp7; auto tmp9 = tmp8.round(); auto tmp10 = tmp9 + tmp2; auto tmp11 = at::vec::Vectorized<float>(static_cast<float>(0)); auto tmp12 = at::vec::maximum(tmp10, tmp11); auto tmp13 = at::vec::Vectorized<float>(static_cast<float>(255)); auto tmp14 = at::vec::minimum(tmp12, tmp13); auto tmp15 = (tmp14); tmp15.store_float_as_uint8(out_ptr0 + static_cast<long>(16*i0)); } #pragma omp for simd simdlen(8) for(long i0=static_cast<long>(432); i0<static_cast<long>(441); i0+=static_cast<long>(1)) { auto tmp0 = in_ptr0[static_cast<long>(i0)]; auto tmp1 = static_cast<float>(tmp0); auto tmp2 = static_cast<float>(100); auto tmp3 = tmp1 - tmp2; auto tmp4 = static_cast<float>(0.01); auto tmp5 = tmp3 * tmp4; auto tmp6 = tmp5 * (tmp5>0); auto tmp7 = static_cast<float>(100.0); auto tmp8 = tmp6 * tmp7; auto tmp9 = std::nearbyint(tmp8); auto tmp10 = tmp9 + tmp2; auto tmp11 = static_cast<float>(0); auto tmp12 = (tmp11 != tmp11) ? tmp11 : std::max(tmp10, tmp11); auto tmp13 = static_cast<float>(255); auto tmp14 = (tmp13 != tmp13) ? tmp13 : std::min(tmp12, tmp13); auto tmp15 = static_cast<unsigned char>(tmp14); out_ptr0[static_cast<long>(i0)] = tmp15; } } } } ''') ``` **Test Plan** ``` cd test/inductor && python -m pytest test_cpu_repro.py -k test_decomposed_dequant_relu_quant ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/98489 Approved by: https://github.com/jgong5, https://github.com/jansel
Committer
Parents
Loading