onnxruntime
3c6fa0ef - Add support for EP selection delegate callback to Python bindings (#24634)

Commit
291 days ago
Add support for EP selection delegate callback to Python bindings (#24634) ### Description Follow up to https://github.com/microsoft/onnxruntime/pull/24614 Example Python program (adapted from unit tests) that specifies a custom EP selection function to select a OrtEpDevice(s) for compiling: ```python def test_compile_with_ep_selection_delegate(self): # ... # User's custom EP selection function. def my_delegate( ep_devices: Sequence[onnxrt.OrtEpDevice], model_metadata: dict[str, str], runtime_metadata: dict[str, str], max_selections: int, ) -> Sequence[onnxrt.OrtEpDevice]: self.assertTrue(len(model_metadata) > 0) self.assertTrue(ep_devices and max_selections > 0) # Select the first and last devices (if there are more than one) selected_devices = [ep_devices[0]] if max_selections > 2 and len(ep_devices) > 1: selected_devices.append(ep_devices[-1]) # ORT CPU EP is always last return selected_devices session_options = onnxrt.SessionOptions() session_options.set_provider_selection_policy_delegate(my_delegate) model_compiler = onnxrt.ModelCompiler( session_options, input_model_path, embed_compiled_data_into_model=True, external_initializers_file_path=None, ) model_compiler.compile_to_file(output_model_path) ``` How to raise an exception from the Python EP selection function: ```python # User's custom EP selection function. custom_error_message = "MY ERROR" def my_delegate_that_fails( ep_devices: Sequence[onnxrt.OrtEpDevice], model_metadata: dict[str, str], runtime_metadata: dict[str, str], max_selections: int, ) -> Sequence[onnxrt.OrtEpDevice]: self.assertTrue(len(ep_devices) >= 1) raise ValueError(custom_error_message) sess_options = onnxrt.SessionOptions() sess_options.set_provider_selection_policy_delegate(my_delegate_that_fails) # Create session and expect ORT to raise a Fail exception that contains our message. with self.assertRaises(Fail) as context: onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) self.assertIn(custom_error_message, str(context.exception)) ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
Parents
Loading