From e15f515ff297d046ad3a53e0e83d8b8a2a618eb4 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Wed, 1 Nov 2023 22:30:46 +0000 Subject: [PATCH] TLS support for workflows (#632) * Adds tls support for workflows Signed-off-by: Elena Kolevska * Removes default arguments from workflow examples Signed-off-by: Elena Kolevska * Fixes broken demo workflow test Signed-off-by: Elena Kolevska --------- Signed-off-by: Elena Kolevska Co-authored-by: Bernd Verst --- examples/demo_workflow/app.py | 16 ++++----- examples/workflow/monitor.py | 2 +- .../dapr/ext/workflow/dapr_workflow_client.py | 35 ++++++++++--------- .../dapr/ext/workflow/workflow_runtime.py | 11 +++++- 4 files changed, 37 insertions(+), 27 deletions(-) diff --git a/examples/demo_workflow/app.py b/examples/demo_workflow/app.py index 892f7b95..edbd1b33 100644 --- a/examples/demo_workflow/app.py +++ b/examples/demo_workflow/app.py @@ -32,11 +32,11 @@ def hello_world_wf(ctx: DaprWorkflowContext, wf_input): print(f'{wf_input}') - yield ctx.call_activity(hello_act, wf_input=1) - yield ctx.call_activity(hello_act, wf_input=10) + yield ctx.call_activity(hello_act, input=1) + yield ctx.call_activity(hello_act, input=10) yield ctx.wait_for_external_event("event1") - yield ctx.call_activity(hello_act, wf_input=100) - yield ctx.call_activity(hello_act, wf_input=1000) + yield ctx.call_activity(hello_act, input=100) + yield ctx.call_activity(hello_act, input=1000) def hello_act(ctx: WorkflowActivityContext, wf_input): @@ -47,9 +47,7 @@ def hello_act(ctx: WorkflowActivityContext, wf_input): def main(): with DaprClient() as d: - host = settings.DAPR_RUNTIME_HOST - port = settings.DAPR_GRPC_PORT - workflow_runtime = WorkflowRuntime(host, port) + workflow_runtime = WorkflowRuntime() workflow_runtime.register_workflow(hello_world_wf) workflow_runtime.register_activity(hello_act) workflow_runtime.start() @@ -107,8 +105,8 @@ def main(): sleep(1) get_response = d.get_workflow(instance_id=instance_id, workflow_component=workflow_component) - print( - f"Get response from {workflow_name} after terminate call: {get_response.runtime_status}") + print(f"Get response from {workflow_name} " + f"after terminate call: {get_response.runtime_status}") # Purge Test d.purge_workflow(instance_id=instance_id, workflow_component=workflow_component) diff --git a/examples/workflow/monitor.py b/examples/workflow/monitor.py index ff80d5d4..d18ad929 100644 --- a/examples/workflow/monitor.py +++ b/examples/workflow/monitor.py @@ -52,7 +52,7 @@ def send_alert(ctx, message: str): if __name__ == '__main__': - workflowRuntime = wf.WorkflowRuntime("localhost", "50001") + workflowRuntime = wf.WorkflowRuntime() workflowRuntime.register_workflow(status_monitor_workflow) workflowRuntime.register_activity(check_status) workflowRuntime.register_activity(send_alert) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index 5e94d190..8f9207fa 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -22,9 +22,11 @@ from dapr.ext.workflow.workflow_state import WorkflowState from dapr.ext.workflow.workflow_context import Workflow from dapr.ext.workflow.util import getAddress + +from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings - +from dapr.conf.helpers import GrpcEndpoint T = TypeVar('T') TInput = TypeVar('TInput') @@ -40,16 +42,22 @@ class DaprWorkflowClient: This client is intended to be used by workflow application, not by general purpose application. """ + def __init__(self, host: Optional[str] = None, port: Optional[str] = None): address = getAddress(host, port) + + try: + uri = GrpcEndpoint(address) + except ValueError as error: + raise DaprInternalError(f'{error}') from error + metadata = tuple() if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) - self.__obj = client.TaskHubGrpcClient(host_address=address, metadata=metadata) + self.__obj = client.TaskHubGrpcClient(host_address=uri.endpoint, metadata=metadata, + secure_channel=uri.tls) - def schedule_new_workflow(self, - workflow: Workflow, *, - input: Optional[TInput] = None, + def schedule_new_workflow(self, workflow: Workflow, *, input: Optional[TInput] = None, instance_id: Optional[str] = None, start_at: Optional[datetime] = None) -> str: """Schedules a new workflow instance for execution. @@ -67,9 +75,8 @@ def schedule_new_workflow(self, Returns: The ID of the scheduled workflow instance. """ - return self.__obj.schedule_new_orchestration(workflow.__name__, - input=input, instance_id=instance_id, - start_at=start_at) + return self.__obj.schedule_new_orchestration(workflow.__name__, input=input, + instance_id=instance_id, start_at=start_at) def get_workflow_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[WorkflowState]: @@ -88,8 +95,7 @@ def get_workflow_state(self, instance_id: str, *, state = self.__obj.get_orchestration_state(instance_id, fetch_payloads=fetch_payloads) return WorkflowState(state) if state else None - def wait_for_workflow_start(self, instance_id: str, *, - fetch_payloads: bool = False, + def wait_for_workflow_start(self, instance_id: str, *, fetch_payloads: bool = False, timeout_in_seconds: int = 60) -> Optional[WorkflowState]: """Waits for a workflow to start running and returns a WorkflowState object that contains metadata about the started workflow. @@ -109,13 +115,11 @@ def wait_for_workflow_start(self, instance_id: str, *, WorkflowState record that describes the workflow instance and its execution status. If the specified workflow isn't found, the WorkflowState.Exists value will be false. """ - state = self.__obj.wait_for_orchestration_start(instance_id, - fetch_payloads=fetch_payloads, + state = self.__obj.wait_for_orchestration_start(instance_id, fetch_payloads=fetch_payloads, timeout=timeout_in_seconds) return WorkflowState(state) if state else None - def wait_for_workflow_completion(self, instance_id: str, *, - fetch_payloads: bool = True, + def wait_for_workflow_completion(self, instance_id: str, *, fetch_payloads: bool = True, timeout_in_seconds: int = 60) -> Optional[WorkflowState]: """Waits for a workflow to complete and returns a WorkflowState object that contains metadata about the started instance. @@ -172,8 +176,7 @@ def raise_workflow_event(self, instance_id: str, event_name: str, *, """ return self.__obj.raise_orchestration_event(instance_id, event_name, data=data) - def terminate_workflow(self, instance_id: str, *, - output: Optional[Any] = None): + def terminate_workflow(self, instance_id: str, *, output: Optional[Any] = None): """Terminates a running workflow instance and updates its runtime status to WorkflowRuntimeStatus.Terminated This method internally enqueues a "terminate" message in the task hub. When the task hub worker processes this message, it will update the runtime diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 2cc25ada..a28413df 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -21,8 +21,11 @@ from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.workflow_activity_context import Activity, WorkflowActivityContext from dapr.ext.workflow.util import getAddress + +from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings +from dapr.conf.helpers import GrpcEndpoint T = TypeVar('T') TInput = TypeVar('TInput') @@ -39,7 +42,13 @@ def __init__(self, host: Optional[str] = None, port: Optional[str] = None): metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) address = getAddress(host, port) - self.__worker = worker.TaskHubGrpcWorker(host_address=address, metadata=metadata) + try: + uri = GrpcEndpoint(address) + except ValueError as error: + raise DaprInternalError(f'{error}') from error + + self.__worker = worker.TaskHubGrpcWorker(host_address=uri.endpoint, metadata=metadata, + secure_channel=uri.tls) def register_workflow(self, fn: Workflow): def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None):