pytorch
a6caa9c5 - Add a cpp wrapper for Inductor (#88167)

Commit
2 years ago
Add a cpp wrapper for Inductor (#88167) ## Description Implements https://github.com/pytorch/torchdynamo/issues/1556. This PR adds a cpp wrapper to invoke the generated kernels. The cpp wrapper is turned off by default and can be turned on by setting: ```python from torch._inductor import config config.cpp_wrapper = True ``` ### Example The main part of the generated code: ```python from torch.utils.cpp_extension import load_inline wrapper = ( ''' #include <dlfcn.h> #include <assert.h> std::tuple<at::Tensor, at::Tensor> call_0(std::tuple<at::Tensor, at::Tensor> args) { at::Tensor arg0_1, arg1_1; std::tie(arg0_1, arg1_1) = args; auto buf0 = at::empty_strided({8, 8}, {8, 1}, at::ScalarType::Float); auto buf1 = at::empty_strided({8, 8}, {1, 8}, at::ScalarType::Float); auto kernel0_lib = dlopen("/tmp/torchinductor_user/kn/ckn7ubcn2qbkme2vx5r6antnh5sv6d3o3t6qwdfgfoupnxty6pnm.so", RTLD_NOW); assert(kernel0_lib != nullptr); void (*kernel0)(const float*,const float*,float*,float*); *(void **) (&kernel0) = dlsym(kernel0_lib, "kernel"); kernel0((float*)(arg0_1.data_ptr()), (float*)(arg1_1.data_ptr()), (float*)(buf0.data_ptr()), (float*)(buf1.data_ptr())); arg0_1.reset(); arg1_1.reset(); return std::make_tuple(buf0, buf1); }''' ) module = load_inline( name='inline_extension_c64wpbccpbre3th2k6oxwrjy5bhvxnmkdxkhcfxlsw7xpsg4eabu', cpp_sources=[wrapper], functions=['call_0'], extra_cflags=['-fPIC -Wall -std=c++14 -Wno-unused-variable -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp'], extra_ldflags=['-shared -lgomp'], extra_include_paths=['-I/home/user/pytorch/torch/include -I/home/user/pytorch/torch/include/torch/csrc/api/include -I/home/user/pytorch/torch/include/TH -I/home/user/pytorch/torch/include/THC -I/home/user/miniconda3/envs/pytorch/include/python3.7m']) def _wrap_func(f): def g(args): return f(args) return g call = _wrap_func(module.call_0) ``` ### Next steps The below items will be addressed in upcoming PRs. - [x] Support Reduction: #88561 - [x] Support None: #88560 - [ ] Support ExternKernel - [x] ATen GEMM-related OPs: #88667 - [ ] ATen Conv - [ ] Conv/GEMM fusion OPs - [x] Cache the kernel loading part: #89742 - [ ] De-allocate input buffers when possible by leveraging CPython APIs - [ ] Support Constant Pull Request resolved: https://github.com/pytorch/pytorch/pull/88167 Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
Author
Committer
Parents
Loading