[Quant][PT2E][Inductor] Lower quantized conv to Inductor (#101164)
**Summary**
Enable the lowering path for reference quantized conv after PT2E to Inductor.
The pattern `decomposed dequantize -> aten.convolution -> decomposed quantize` is fused to `quantized.functional.conv1d/2d/3d` and Inductor makes external calls to these ops.
This PR focuses on functionality only. The implementation is expected to have low performance.
Code example:
```Python
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 6, 2, stride=2, padding=0, dilation=1)
def forward(self, x):
return nn.functional.gelu(self.conv(x))
m = M().eval()
example_inputs = (torch.randn(2, 3, 6, 6),)
exported_model, guards = torchdynamo.export(
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="real",
)
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config_inductor = get_x86_inductor_pt2e_backend_config()
prepared_model = prepare_pt2e(
exported_model,
qconfig_mapping,
example_inputs,
backend_config_inductor
)
prepared_model(*example_inputs)
converted_model = convert_pt2e(prepared_model)
run = compile_fx(converted_model, example_inputs)
```
Output code by Inductor
```python
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_weiwen/5d/c5dsrjrcd4jlzryilhxl5hdvcrzsoek52xzzqqy57hcoezvxxxwm.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const long* in_ptr2,
unsigned char* out_ptr0)
{
{
#pragma GCC ivdep
for(long i0=static_cast<long>(0L); i0<static_cast<long>(2L); i0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i1=static_cast<long>(0L); i1<static_cast<long>(3L); i1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i2=static_cast<long>(0L); i2<static_cast<long>(36L); i2+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i2 + (36L*i1) + (108L*i0))];
auto tmp1 = in_ptr1[static_cast<long>(0L)];
auto tmp7 = in_ptr2[static_cast<long>(0L)];
auto tmp2 = 1 / tmp1;
auto tmp3 = static_cast<float>(1.0);
auto tmp4 = decltype(tmp2)(tmp2 * tmp3);
auto tmp5 = decltype(tmp0)(tmp0 * tmp4);
auto tmp6 = std::nearbyint(tmp5);
auto tmp8 = static_cast<float>(tmp7);
auto tmp9 = tmp6 + tmp8;
auto tmp10 = static_cast<float>(0);
auto tmp11 = max_propagate_nan(tmp9, tmp10);
auto tmp12 = static_cast<float>(127);
auto tmp13 = min_propagate_nan(tmp11, tmp12);
auto tmp14 = static_cast<unsigned char>(tmp13);
out_ptr0[static_cast<long>(i1 + (3L*i2) + (108L*i0))] = tmp14;
}
}
}
}
}
''')
kernel_cpp_1 = async_compile.cpp('''
#include "/tmp/torchinductor_weiwen/5d/c5dsrjrcd4jlzryilhxl5hdvcrzsoek52xzzqqy57hcoezvxxxwm.h"
extern "C" void kernel(const unsigned char* in_ptr0,
const long* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(long i0=static_cast<long>(0L); i0<static_cast<long>(2L); i0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i1=static_cast<long>(0L); i1<static_cast<long>(6L); i1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long i2=static_cast<long>(0L); i2<static_cast<long>(9L); i2+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i1 + (6L*i2) + (54L*i0))];
auto tmp2 = in_ptr1[static_cast<long>(0L)];
auto tmp5 = in_ptr2[static_cast<long>(0L)];
auto tmp1 = static_cast<float>(tmp0);
auto tmp3 = static_cast<float>(tmp2);
auto tmp4 = tmp1 - tmp3;
auto tmp6 = decltype(tmp4)(tmp4 * tmp5);
auto tmp7 = static_cast<float>(0.5);
auto tmp8 = decltype(tmp6)(tmp6 * tmp7);
auto tmp9 = static_cast<float>(0.7071067811865476);
auto tmp10 = decltype(tmp6)(tmp6 * tmp9);
auto tmp11 = std::erf(tmp10);
auto tmp12 = static_cast<float>(1);
auto tmp13 = tmp11 + tmp12;
auto tmp14 = decltype(tmp8)(tmp8 * tmp13);
out_ptr0[static_cast<long>(i2 + (9L*i1) + (54L*i0))] = tmp14;
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1 = args
args.clear()
buf0 = torch.ops.quantized_decomposed.quantize_per_channel.default(arg0_1, arg4_1, arg5_1, 0, -128, 127, torch.int8)
del arg0_1
buf1 = buf0
assert_size_stride(buf1, (6, 3, 2, 2), (12, 4, 2, 1))
del buf0
buf2 = empty_strided((2, 3, 6, 6), (108, 1, 18, 3), device='cpu', dtype=torch.uint8)
kernel_cpp_0(c_void_p(arg8_1.data_ptr()), c_void_p(arg2_1.data_ptr()), c_void_p(arg3_1.data_ptr()), c_void_p(buf2.data_ptr()))
del arg8_1
buf2 = torch._make_per_tensor_quantized_tensor(buf2, arg2_1, arg3_1)
buf1 = torch._make_per_channel_quantized_tensor(buf1, arg4_1, arg5_1, 0)
buf3 = torch.ao.nn.quantized.functional.conv2d(buf2, buf1, arg1_1, (2, 2), (0, 0), (1, 1), 1, 'zeros', arg6_1, arg7_1, torch.uint8)
assert_size_stride(buf3, (2, 6, 3, 3), (54, 1, 18, 6))
del arg1_1
del arg2_1
del arg3_1
del arg4_1
del arg5_1
del buf1
del buf2
buf4 = empty_strided((2, 6, 3, 3), (54, 9, 3, 1), device='cpu', dtype=torch.float32)
kernel_cpp_1(c_void_p(buf3.data_ptr()), c_void_p(arg7_1.data_ptr()), c_void_p(arg6_1.data_ptr()), c_void_p(buf4.data_ptr()))
del arg6_1
del arg7_1
return (buf4, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((6, 3, 2, 2), (12, 4, 2, 1), device='cpu', dtype=torch.float32)
arg1_1 = rand_strided((6, ), (1, ), device='cpu', dtype=torch.float32)
arg2_1 = rand_strided((), (), device='cpu', dtype=torch.float32)
arg3_1 = rand_strided((), (), device='cpu', dtype=torch.int64)
arg4_1 = rand_strided((6, ), (1, ), device='cpu', dtype=torch.float32)
arg5_1 = rand_strided((6, ), (1, ), device='cpu', dtype=torch.int64)
arg6_1 = rand_strided((), (), device='cpu', dtype=torch.float32)
arg7_1 = rand_strided((), (), device='cpu', dtype=torch.int64)
arg8_1 = rand_strided((2, 3, 6, 6), (108, 36, 6, 1), device='cpu', dtype=torch.float32)
return print_performance(lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1]), times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.utils import compiled_module_main
compiled_module_main('None', benchmark_compiled_module)
```
**Test plan**
python test/test_quantization.py TestQuantizePT2EFXX86Inductor.test_inductor_qconv_lowering
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101164
Approved by: https://github.com/jgong5, https://github.com/jansel