ce9923a1 - [Quant][PT2E][Inductor] Lower quantized conv to Inductor (#101164)

**Summary**
Enable the lowering path to Inductor for the reference quantized conv pattern produced by PT2E. The pattern `decomposed dequantize -> aten.convolution -> decomposed quantize` is fused into `quantized.functional.conv1d/2d/3d`, and Inductor makes external calls to these ops. This PR focuses on functionality only; the implementation is expected to have low performance.

Code example (imports added for completeness; `prepare_pt2e`, `convert_pt2e`, and `get_x86_inductor_pt2e_backend_config` were private, in-flux APIs at the time of this PR, so their exact import paths depend on the PyTorch version):
```Python
import copy

import torch
import torch.nn as nn
import torch._dynamo as torchdynamo
from torch._inductor.compile_fx import compile_fx
from torch.ao.quantization import QConfigMapping, get_default_qconfig
# prepare_pt2e, convert_pt2e and get_x86_inductor_pt2e_backend_config come
# from the (private) PT2E quantization modules; exact paths vary by version.


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 2, stride=2, padding=0, dilation=1)

    def forward(self, x):
        return nn.functional.gelu(self.conv(x))


m = M().eval()
example_inputs = (torch.randn(2, 3, 6, 6),)
exported_model, guards = torchdynamo.export(
    m,
    *copy.deepcopy(example_inputs),
    aten_graph=True,
    tracing_mode="real",
)
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config_inductor = get_x86_inductor_pt2e_backend_config()
prepared_model = prepare_pt2e(
    exported_model, qconfig_mapping, example_inputs, backend_config_inductor
)
prepared_model(*example_inputs)  # calibration run
converted_model = convert_pt2e(prepared_model)
run = compile_fx(converted_model, example_inputs)
```
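After `convert_pt2e`, the graph contains the reference quantized pattern described in the summary. Below is a minimal, self-contained sketch of that pattern, assuming the decomposed ops registered under `torch.ops.quantized_decomposed`; all tensor names and quantization parameters here are illustrative, not taken from the actual graph:
```Python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 (registers the quantized_decomposed ops)

# Illustrative quantized inputs matching the shapes in the example above.
x_uint8 = torch.randint(0, 128, (2, 3, 6, 6), dtype=torch.uint8)
w_int8 = torch.randint(-128, 128, (6, 3, 2, 2), dtype=torch.int8)
bias = torch.randn(6)
x_scale, x_zero_point = 0.02, 0
w_scales = torch.full((6,), 0.01)
w_zero_points = torch.zeros(6, dtype=torch.int64)
y_scale, y_zero_point = 0.05, 0

# The pattern the lowering pass matches:
# dequantize (activation and weight) -> aten.convolution -> quantize.
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x_uint8, x_scale, x_zero_point, 0, 127, torch.uint8)
w_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
    w_int8, w_scales, w_zero_points, 0, -128, 127, torch.int8)
y_fp32 = torch.ops.aten.convolution(
    x_fp32, w_fp32, bias, [2, 2], [0, 0], [1, 1], False, [0, 0], 1)
y_uint8 = torch.ops.quantized_decomposed.quantize_per_tensor(
    y_fp32, y_scale, y_zero_point, 0, 127, torch.uint8)
```
In the generated code below, this region becomes a single external call to `torch.ao.nn.quantized.functional.conv2d`, while `kernel_cpp_0` quantizes the activation and `kernel_cpp_1` fuses the output dequantize with `gelu`.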
Output code by Inductor:
```python
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile

from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_weiwen/5d/c5dsrjrcd4jlzryilhxl5hdvcrzsoek52xzzqqy57hcoezvxxxwm.h"
extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const long* in_ptr2,
                       unsigned char* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(2L); i0+=static_cast<long>(1L))
        {
            #pragma GCC ivdep
            for(long i1=static_cast<long>(0L); i1<static_cast<long>(3L); i1+=static_cast<long>(1L))
            {
                #pragma GCC ivdep
                for(long i2=static_cast<long>(0L); i2<static_cast<long>(36L); i2+=static_cast<long>(1L))
                {
                    auto tmp0 = in_ptr0[static_cast<long>(i2 + (36L*i1) + (108L*i0))];
                    auto tmp1 = in_ptr1[static_cast<long>(0L)];
                    auto tmp7 = in_ptr2[static_cast<long>(0L)];
                    auto tmp2 = 1 / tmp1;
                    auto tmp3 = static_cast<float>(1.0);
                    auto tmp4 = decltype(tmp2)(tmp2 * tmp3);
                    auto tmp5 = decltype(tmp0)(tmp0 * tmp4);
                    auto tmp6 = std::nearbyint(tmp5);
                    auto tmp8 = static_cast<float>(tmp7);
                    auto tmp9 = tmp6 + tmp8;
                    auto tmp10 = static_cast<float>(0);
                    auto tmp11 = max_propagate_nan(tmp9, tmp10);
                    auto tmp12 = static_cast<float>(127);
                    auto tmp13 = min_propagate_nan(tmp11, tmp12);
                    auto tmp14 = static_cast<unsigned char>(tmp13);
                    out_ptr0[static_cast<long>(i1 + (3L*i2) + (108L*i0))] = tmp14;
                }
            }
        }
    }
}
''')

kernel_cpp_1 = async_compile.cpp('''
#include "/tmp/torchinductor_weiwen/5d/c5dsrjrcd4jlzryilhxl5hdvcrzsoek52xzzqqy57hcoezvxxxwm.h"
extern "C" void kernel(const unsigned char* in_ptr0,
                       const long* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(2L); i0+=static_cast<long>(1L))
        {
            #pragma GCC ivdep
            for(long i1=static_cast<long>(0L); i1<static_cast<long>(6L); i1+=static_cast<long>(1L))
            {
                #pragma GCC ivdep
                for(long i2=static_cast<long>(0L); i2<static_cast<long>(9L); i2+=static_cast<long>(1L))
                {
                    auto tmp0 = in_ptr0[static_cast<long>(i1 + (6L*i2) + (54L*i0))];
                    auto tmp2 = in_ptr1[static_cast<long>(0L)];
                    auto tmp5 = in_ptr2[static_cast<long>(0L)];
                    auto tmp1 = static_cast<float>(tmp0);
                    auto tmp3 = static_cast<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp6 = decltype(tmp4)(tmp4 * tmp5);
                    auto tmp7 = static_cast<float>(0.5);
                    auto tmp8 = decltype(tmp6)(tmp6 * tmp7);
                    auto tmp9 = static_cast<float>(0.7071067811865476);
                    auto tmp10 = decltype(tmp6)(tmp6 * tmp9);
                    auto tmp11 = std::erf(tmp10);
                    auto tmp12 = static_cast<float>(1);
                    auto tmp13 = tmp11 + tmp12;
                    auto tmp14 = decltype(tmp8)(tmp8 * tmp13);
                    out_ptr0[static_cast<long>(i2 + (9L*i1) + (54L*i0))] = tmp14;
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1 = args
    args.clear()
    buf0 = torch.ops.quantized_decomposed.quantize_per_channel.default(arg0_1, arg4_1, arg5_1, 0, -128, 127, torch.int8)
    del arg0_1
    buf1 = buf0
    assert_size_stride(buf1, (6, 3, 2, 2), (12, 4, 2, 1))
    del buf0
    buf2 = empty_strided((2, 3, 6, 6), (108, 1, 18, 3), device='cpu', dtype=torch.uint8)
    kernel_cpp_0(c_void_p(arg8_1.data_ptr()), c_void_p(arg2_1.data_ptr()), c_void_p(arg3_1.data_ptr()), c_void_p(buf2.data_ptr()))
    del arg8_1
    buf2 = torch._make_per_tensor_quantized_tensor(buf2, arg2_1, arg3_1)
    buf1 = torch._make_per_channel_quantized_tensor(buf1, arg4_1, arg5_1, 0)
    buf3 = torch.ao.nn.quantized.functional.conv2d(buf2, buf1, arg1_1, (2, 2), (0, 0), (1, 1), 1, 'zeros', arg6_1, arg7_1, torch.uint8)
    assert_size_stride(buf3, (2, 6, 3, 3), (54, 1, 18, 6))
    del arg1_1
    del arg2_1
    del arg3_1
    del arg4_1
    del arg5_1
    del buf1
    del buf2
    buf4 = empty_strided((2, 6, 3, 3), (54, 9, 3, 1), device='cpu', dtype=torch.float32)
    kernel_cpp_1(c_void_p(buf3.data_ptr()), c_void_p(arg7_1.data_ptr()), c_void_p(arg6_1.data_ptr()), c_void_p(buf4.data_ptr()))
    del arg6_1
    del arg7_1
    return (buf4, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((6, 3, 2, 2), (12, 4, 2, 1), device='cpu', dtype=torch.float32)
    arg1_1 = rand_strided((6, ), (1, ), device='cpu', dtype=torch.float32)
    arg2_1 = rand_strided((), (), device='cpu', dtype=torch.float32)
    arg3_1 = rand_strided((), (), device='cpu', dtype=torch.int64)
    arg4_1 = rand_strided((6, ), (1, ), device='cpu', dtype=torch.float32)
    arg5_1 = rand_strided((6, ), (1, ), device='cpu', dtype=torch.int64)
    arg6_1 = rand_strided((), (), device='cpu', dtype=torch.float32)
    arg7_1 = rand_strided((), (), device='cpu', dtype=torch.int64)
    arg8_1 = rand_strided((2, 3, 6, 6), (108, 36, 6, 1), device='cpu', dtype=torch.float32)
    return print_performance(lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1]), times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.utils import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```

**Test plan**
python test/test_quantization.py TestQuantizePT2EFXX86Inductor.test_inductor_qconv_lowering

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101164
Approved by: https://github.com/jgong5, https://github.com/jansel