onnxruntime
db6d83b6 - [EP ABI] Initial support for kernel-based EPs (#26206)

Commit
53 days ago
[EP ABI] Initial support for kernel-based EPs (#26206)

### Description

This PR adds an initial set of C APIs necessary to support kernel registration for plugin EPs.

### Example use

The example plugin EP implementation now registers `MemcpyFromHost` and `MemcpyToHost` operator kernels using the new APIs. New utilities in the example implementation make the process of defining operator kernels very similar to the existing process used by provider-bridge EPs.

First, the operator kernel class is defined:

```c++
// File: onnxruntime/test/autoep/library/kernels/memcpy.h

struct Memcpy : public OrtKernelImpl {
  static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Memcpy>& kernel);

  Memcpy(const OrtKernelInfo* info, void* state);

  static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;
  static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept;

  OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) noexcept;

 private:
  const OrtKernelInfo* info_;
  void* state_;  // Custom state passed from OrtEp
};
```

Then, a macro defines a function that can be called to register the operator with the EP's kernel registry:

```c++
// File: onnxruntime/test/autoep/library/kernels/memcpy.cc

ONNX_OPERATOR_KERNEL_EX(
    MemcpyFromHost,
    kOnnxDomain,
    1,
    (Ort::KernelDefBuilder()
         .SetInputMemType(0, OrtMemType::OrtMemTypeCPUInput)
         .AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
    Memcpy)

ONNX_OPERATOR_KERNEL_EX(
    MemcpyToHost,
    kOnnxDomain,
    1,
    (Ort::KernelDefBuilder()
         .SetOutputMemType(0, OrtMemType::OrtMemTypeCPUOutput)
         .AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
    Memcpy)
```

Lastly, the functions defined by the above macro are entered into a table:

```c++
// File: onnxruntime/test/autoep/library/ep_kernel_registration.cc

// Include kernel files:
#include "kernels/memcpy.h"

// Forward declarations of kernel classes used as template args for BuildKernelCreateInfo
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyFromHost);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyToHost);

// Table of BuildKernelCreateInfo functions for each operator
static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = {
    BuildKernelCreateInfo<void>,  // Dummy to avoid table becoming empty.
    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyFromHost)>,
    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 1, MemcpyToHost)>,
};
```

The [example EP processes the entries in the above table](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/ep_kernel_registration.cc) to add information about the supported operator kernels to the EP's kernel registry (`OrtKernelRegistry`).

Additionally, during the call to `OrtEp::GetCapability`, an EP can now look up registered kernel definitions via the new API `EpGraphSupportInfo_LookUpKernel`. Note that an EP would not normally look up kernels for `MemcpyFromHost` and `MemcpyToHost`, which are inserted by ORT. Instead, the API would be used to look up other registered operator kernels, such as `Conv`.

```c++
static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph,
                                                 OrtEpGraphSupportInfo* graph_support_info) noexcept {
  // ...

  for (const OrtNode* node : nodes) {
    const OrtKernelDef* kernel_def = nullptr;
    OrtStatus* status = this_ep->ep_api->EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def);

    if (status != nullptr) {
      return status;
    }

    if (kernel_def != nullptr) {
      // Take node if this EP has a registered kernel for it.
      if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, node);
          st != nullptr) {
        return st;
      }
    }
  }

  return nullptr;
}
```

### EP implementation details

An EP instance (i.e., `OrtEp`) that needs to register operator kernels with ONNX Runtime must implement the following `OrtEp::GetKernelRegistry()` function:

| Function Signature | Description |
|--------------------|-------------|
|**GetKernelRegistry**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtEp* this_ptr`: The OrtEp instance.</li><li>`const OrtKernelRegistry** kernel_registry`: Output parameter set to the EP's kernel registry, which must remain valid throughout the lifetime of the EP.</li></ul>| Gets the execution provider's kernel registry, if any.<br/><br/>**Remarks:** A kernel registry contains kernel creation information for operator kernels supported by an EP.<br/><br/>**Note:** Implementation of this function is optional. If set to NULL, ORT assumes the EP compiles nodes. |

If defined by the EP, the `OrtEp::GetKernelRegistry()` function is [called by ONNX Runtime](https://github.com/microsoft/onnxruntime/blob/0f7145f3809103c123de2d281a6b310677e6d56c/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc#L146-L147) after creating an instance of the `OrtEp` in order to retrieve the EP's kernel registry.

#### APIs used by EP to add entries to kernel registry

An EP's kernel registry (`OrtKernelRegistry`) contains **information** necessary for the (later) creation of operator kernels supported by an EP. Conceptually, a kernel registry contains an array of "kernel creation information" elements, one per operator. Each such element consists of:

- A kernel **definition** (`OrtKernelDef`), which specifies operator type, supported versions, type constraints, I/O memory types, etc.
- A function of type `OrtKernelCreateFunc` that ORT calls to create an instance of the kernel (`OrtKernelImpl`).
- Custom opaque state (provided by the `OrtEp`) that is passed to the `OrtKernelCreateFunc`.

An EP uses the following `OrtEpApi::KernelRegistry_AddKernel()` function to add an entry for one supported operator.

| Function Signature | Description |
|--------------------|-------------|
|**KernelRegistry_AddKernel**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelRegistry* kernel_registry`: The OrtKernelRegistry instance.</li><li>`const OrtKernelDef* kernel_def`: The kernel definition, which includes operator type, version, EP name, type constraints, etc.</li><li>`OrtKernelCreateFunc kernel_create_func`: Function that creates an instance of the operator kernel as an OrtKernelImpl instance.</li><li>`void* kernel_create_func_state`: Custom state passed to the kernel creation function. Can be null.</li></ul>| Adds kernel creation information for a supported operator kernel to the given kernel registry.<br/><br/>**Remarks:** Refer to OrtEp::GetKernelRegistry, which returns an EP's kernel registry to ORT. |

##### Building a kernel definition

An EP uses a kernel definition builder (`OrtKernelDefBuilder`) to create a kernel definition (`OrtKernelDef`). The following table lists **some** of the C APIs related to building a kernel definition. The above `ONNX_OPERATOR_KERNEL_EX` macro [uses these APIs](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/kernels/utils.h#L42).

| Function Signature | Description |
|--------------------|-------------|
|**KernelDefBuilder_SetOperatorType**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelDefBuilder* kernel_def_builder`: The OrtKernelDefBuilder instance.</li><li>`const char* op_type`: A null-terminated string representing the operator type.</li></ul>| Sets the kernel's operator type. |
|**KernelDefBuilder_SetDomain**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelDefBuilder* kernel_def_builder`: The OrtKernelDefBuilder instance.</li><li>`const char* domain`: A null-terminated string representing the operator's domain.</li></ul>| Sets the kernel's domain. |
| ... | ... |
|**KernelDefBuilder_Build**<br/><br/>**Returns**:`OrtStatus*`<br/><br/>**Parameters:**<br/><ul><li>`OrtKernelDefBuilder* kernel_def_builder`: The OrtKernelDefBuilder instance.</li><li>`OrtKernelDef** kernel_def_out`: The new OrtKernelDef instance.</li></ul>| Creates an OrtKernelDef instance from the given kernel definition builder. |

##### Defining a kernel implementation

An EP defines a kernel implementation by initializing an instance of `OrtKernelImpl` (shown below) with function pointers for computation, release, etc.

```c++
struct OrtKernelImpl {
  uint32_t ort_version_supported;  ///< Must be initialized to ORT_API_VERSION

  /** \brief Computation function called to execute the kernel on an EP.
   *
   * \param[in] this_ptr The OrtKernelImpl instance.
   * \param[in] context The OrtKernelContext instance that provides access to the inputs and outputs.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   * \since Version 1.24.
   */
  ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context);

  /** \brief Called by ORT to release the OrtKernelImpl instance and its resources.
   *
   * \param[in] this_ptr The OrtKernelImpl instance.
   *
   * \since Version 1.24.
   */
  ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr);
};
```

As shown previously, the example EP creates a `Memcpy` class that inherits from `OrtKernelImpl` and [implements the above functions](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/kernels/memcpy.cc).
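
The inheritance pattern described above can be illustrated with a minimal, self-contained sketch. It does not use the ONNX Runtime headers: `MockStatus`, `MockKernelContext`, `MockKernelImpl`, `CopyKernel`, and `RunOnce` are hypothetical stand-ins invented for this illustration, and the real `OrtKernelImpl` and `Memcpy` definitions may differ in detail. The sketch only shows how a C++ class can derive from a C struct of function pointers, install static trampolines in its constructor, and recover its own type from `this_ptr`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the ORT C types (assumptions, not the real API).
struct MockStatus { int code; };
struct MockKernelContext { float input; float output; };

// C-style kernel interface: a version field plus function pointers.
struct MockKernelImpl {
  uint32_t version_supported;
  MockStatus* (*Compute)(MockKernelImpl* this_ptr, MockKernelContext* context);
  void (*Release)(MockKernelImpl* this_ptr);
};

// Kernel class that derives from the C struct and fills in its function pointers.
struct CopyKernel : MockKernelImpl {
  explicit CopyKernel(int* release_counter)
      : MockKernelImpl{}, release_counter_(release_counter) {
    version_supported = 1;  // placeholder version for this sketch
    Compute = ComputeImpl;
    Release = ReleaseImpl;
  }

  ~CopyKernel() {
    if (release_counter_ != nullptr) ++*release_counter_;
  }

  // Static trampoline: recover the derived type, then forward to the member.
  static MockStatus* ComputeImpl(MockKernelImpl* this_ptr, MockKernelContext* context) noexcept {
    return static_cast<CopyKernel*>(this_ptr)->DoCompute(context);
  }

  // Static trampoline: the kernel owns itself, so Release deletes it.
  static void ReleaseImpl(MockKernelImpl* this_ptr) noexcept {
    delete static_cast<CopyKernel*>(this_ptr);
  }

  MockStatus* DoCompute(MockKernelContext* context) noexcept {
    context->output = context->input;  // trivial "copy" computation
    return nullptr;                    // nullptr means success in this sketch
  }

 private:
  int* release_counter_;  // lets a caller observe that Release ran
};

// Simulates the runtime side: it sees only the C struct and its function pointers.
float RunOnce(MockKernelImpl* kernel, float input) {
  MockKernelContext ctx{input, 0.0f};
  MockStatus* status = kernel->Compute(kernel, &ctx);
  assert(status == nullptr);
  kernel->Release(kernel);  // ownership ends here
  return ctx.output;
}
```

A caller would allocate the kernel with `new CopyKernel(&counter)` and hand the resulting `MockKernelImpl*` to `RunOnce`. Keeping the trampolines `static` is what allows them to be stored in plain C function pointers, while the instance state stays in the derived class.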
##### Defining a kernel creation function

An EP must provide a function of type `OrtKernelCreateFunc` that ORT can later call to create an instance of a kernel (`OrtKernelImpl`). The signature of the `OrtKernelCreateFunc` is shown below.

```c++
/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel.
 *
 * \param[in] ctx Unused/reserved for future use.
 * \param[in] kernel_create_func_state Opaque state initially provided by the EP that registered the kernel.
 *                                     Refer to OrtEpApi::KernelRegistry_AddKernel(). May be null.
 * \param[in] info The OrtKernelInfo instance that provides access to the kernel's input and output characteristics.
 * \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance.
 *
 * \snippet{doc} snippets.dox OrtStatus Return Value
 *
 * \since Version 1.24.
 */
typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ OrtKernelCreateContext* ctx,  // unused/reserved as of 1.24
                                                      _In_ void* kernel_create_func_state,
                                                      _In_ const OrtKernelInfo* info,
                                                      _Outptr_result_maybenull_ OrtKernelImpl** kernel_out);
```

The example EP declares kernel creation functions via use of the previously mentioned `ONNX_OPERATOR_KERNEL_EX` [macro](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-kernel-based-eps/onnxruntime/test/autoep/library/kernels/utils.h#L56-L64). If one were to expand the macro call, the kernel creation function for `MemcpyFromHost` would look similar to the following snippet:

```c++
OrtStatus* ORT_API_CALL CreateMemcpyKernel(OrtKernelCreateContext* /*ctx*/, void* kernel_create_func_state,
                                           const OrtKernelInfo* info, OrtKernelImpl** kernel_out) {
  *kernel_out = nullptr;

  std::unique_ptr<Memcpy> kernel;
  RETURN_IF_ERROR(Memcpy::Create(info, kernel_create_func_state, kernel));

  *kernel_out = kernel.release();
  return nullptr;
}
```

### Motivation and Context

<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>