diff --git a/st2common/st2common/models/db/workflow.py b/st2common/st2common/models/db/workflow.py
index c3f7eab6bb..0937b69ff0 100644
--- a/st2common/st2common/models/db/workflow.py
+++ b/st2common/st2common/models/db/workflow.py
@@ -25,7 +25,7 @@
 from st2common.util import date as date_utils
 
 
-__all__ = ["WorkflowExecutionDB", "TaskExecutionDB"]
+__all__ = ["WorkflowExecutionDB", "TaskExecutionDB", "TaskItemStateDB"]
 
 
 LOG = logging.getLogger(__name__)
@@ -85,4 +85,35 @@ class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionField
     }
 
 
-MODELS = [WorkflowExecutionDB, TaskExecutionDB]
+class TaskItemStateDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin):
+    """
+    Model for storing the state of one item of a task with items (itemized task).
+
+    Each item gets its own document so that updating one item does not require
+    serializing and deserializing the results of every other item in the task.
+    """
+
+    RESOURCE_TYPE = types.ResourceType.EXECUTION
+
+    # ID of the owning task execution, stored as a string.
+    task_execution = me.StringField(required=True)
+    # Position of the item within the task's list of items.
+    item_id = me.IntField(required=True)
+    status = me.StringField(required=True)
+    result = JSONDictEscapedFieldCompatibilityField()
+    context = JSONDictEscapedFieldCompatibilityField()
+    start_timestamp = db_field_types.ComplexDateTimeField(
+        default=date_utils.get_datetime_utc_now
+    )
+    end_timestamp = db_field_types.ComplexDateTimeField()
+
+    meta = {
+        "indexes": [
+            {"fields": ["task_execution"]},
+            # An item ID may appear only once per task execution.
+            {"fields": ["task_execution", "item_id"], "unique": True},
+        ]
+    }
+
+
+MODELS = [WorkflowExecutionDB, TaskExecutionDB, TaskItemStateDB]
diff --git a/st2common/st2common/persistence/workflow.py b/st2common/st2common/persistence/workflow.py
index 49468bd9ef..b0b3bfe44d 100644
--- a/st2common/st2common/persistence/workflow.py
+++ b/st2common/st2common/persistence/workflow.py
@@ -21,7 +21,7 @@
 from st2common.persistence import base as persistence
 
 
-__all__ = ["WorkflowExecution", "TaskExecution"]
+__all__ = ["WorkflowExecution", "TaskExecution", "TaskItemState"]
 
 
 class WorkflowExecution(persistence.StatusBasedResource):
@@ -55,3 +55,49 @@ def _get_impl(cls):
     @classmethod
     def delete_by_query(cls, *args, **query):
         return cls._get_impl().delete_by_query(*args, **query)
+
+
+class TaskItemState(persistence.StatusBasedResource):
+    """Persistence access for the per-item state records of itemized tasks."""
+
+    impl = db.ChangeRevisionMongoDBAccess(wf_db_models.TaskItemStateDB)
+    publisher = None
+
+    @classmethod
+    def _get_impl(cls):
+        return cls.impl
+
+    @classmethod
+    def get_by_task_and_item(cls, task_execution_id, item_id):
+        """
+        Retrieve the state record for a specific item in a task execution.
+
+        Args:
+            task_execution_id: ID of the task execution, as a string
+            item_id: ID of the specific item
+
+        Returns:
+            TaskItemStateDB: The state record for the specified item. It may be
+            None if no record matches (verify against the access layer).
+        """
+        return cls._get_impl().get(task_execution=task_execution_id, item_id=item_id)
+
+    @classmethod
+    def query_by_task_execution(cls, task_execution_id):
+        """
+        Retrieve all item state records for a task execution.
+
+        No ordering is requested here; callers that need the records by item
+        position should sort them on item_id.
+
+        Args:
+            task_execution_id: ID of the task execution, as a string
+
+        Returns:
+            The query result containing the TaskItemStateDB objects of the task
+        """
+        return cls.query(task_execution=task_execution_id)
+
+    @classmethod
+    def delete_by_query(cls, *args, **query):
+        return cls._get_impl().delete_by_query(*args, **query)
diff --git a/st2common/st2common/services/workflows.py b/st2common/st2common/services/workflows.py
index b84671f8b1..63172d875f 100644
--- a/st2common/st2common/services/workflows.py
+++ b/st2common/st2common/services/workflows.py
@@ -611,16 +611,30 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req):
         status=statuses.REQUESTED,
     )
 
-    # Prepare the result format for itemized task execution.
-    if task_ex_db.itemized:
-        task_ex_db.result = {"items": [None] * task_ex_db.items_count}
+    # Itemized tasks keep per-item state in separate TaskItemStateDB records, so
+    # the task execution itself only carries the item count until it completes.
+    has_items = bool(task_ex_db.itemized and task_ex_db.items_count > 0)
+
+    if has_items:
+        task_ex_db.result = {"items_count": task_ex_db.items_count}
 
     # Insert new record into the database.
     task_ex_db = wf_db_access.TaskExecution.insert(task_ex_db, publish=False)
     task_ex_id = str(task_ex_db.id)
     msg = 'Task execution "%s" created for task "%s", route "%s".'
     update_progress(wf_ex_db, msg % (task_ex_id, task_id, str(task_route)))
 
+    # Create one state record per item now that the task execution has an ID.
+    if has_items:
+        for i in range(task_ex_db.items_count):
+            item_state_db = wf_db_models.TaskItemStateDB(
+                task_execution=task_ex_id,
+                item_id=i,
+                status=statuses.REQUESTED,
+                context={},  # Populated when the action for this item is requested.
+            )
+            wf_db_access.TaskItemState.insert(item_state_db, publish=False)
+
     try:
         # Return here if no action is specified in task spec.
         if task_spec.action is None:
@@ -759,6 +773,18 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non
     if item_id is not None:
         ac_ex_ctx["orquesta"]["item_id"] = item_id
 
+        # Record the action execution context on the item's state record.
+        item_state_db = wf_db_access.TaskItemState.get_by_task_and_item(
+            str(task_ex_db.id), item_id
+        )
+
+        if item_state_db is None:
+            msg = 'Unable to find state record for item "%s" of task execution "%s".'
+            raise Exception(msg % (item_id, str(task_ex_db.id)))
+
+        item_state_db.context = ac_ex_ctx
+        wf_db_access.TaskItemState.update(item_state_db, publish=False)
+
     # Render action execution parameters and setup action execution object.
     ac_ex_params = param_utils.render_live_params(
         runner_type_db.runner_parameters or {},
@@ -1256,27 +1282,47 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx
         msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), item_id)
         update_progress(wf_ex_db, msg, severity="debug")
 
-        task_ex_db.result["items"][item_id] = {
-            "status": ac_ex_status,
-            "result": ac_ex_result,
-        }
+        # Persist the outcome on this item's own state record.
+        item_state_db = wf_db_access.TaskItemState.get_by_task_and_item(
+            task_ex_id, item_id
+        )
+
+        if item_state_db is None:
+            msg = 'Unable to find state record for item "%s" of task execution "%s".'
+            raise Exception(msg % (item_id, task_ex_id))
+
+        item_state_db.status = ac_ex_status
+        item_state_db.result = ac_ex_result
+        wf_db_access.TaskItemState.update(item_state_db, publish=False)
 
-        item_statuses = [
-            item.get("status", statuses.UNSET) if item else statuses.UNSET
-            for item in task_ex_db.result["items"]
-        ]
+        # Sort by item ID so the aggregated result lists items by position
+        # regardless of the order in which the query returns the records.
+        item_state_dbs = sorted(
+            wf_db_access.TaskItemState.query_by_task_execution(task_ex_id),
+            key=lambda state: state.item_id,
+        )
+        item_statuses = [state.status for state in item_state_dbs]
 
         task_completed = all(
             [status in statuses.COMPLETED_STATUSES for status in item_statuses]
         )
 
         if task_completed:
+            # If all items are complete, update the task status
             new_task_status = (
                 statuses.SUCCEEDED
                 if all([status == statuses.SUCCEEDED for status in item_statuses])
                 else statuses.FAILED
             )
 
+            # Aggregate the per-item outcomes into the task execution result,
+            # keeping the {"items": [...]} layout indexed by item position.
+            task_ex_db.result = {
+                "items": [
+                    {"status": state.status, "result": state.result}
+                    for state in item_state_dbs
+                ]
+            }
+
             msg = 'Updating task execution from status "%s" to "%s".'
             update_progress(
                 wf_ex_db, msg % (task_ex_db.status, new_task_status), severity="debug"