From 4c268fac8288c86d40940ee82d63f8c9ddb28d6c Mon Sep 17 00:00:00 2001
From: Julian Geiger
Date: Thu, 2 Oct 2025 16:47:51 +0200
Subject: [PATCH 01/11] copy files over from prev branch

---
 .../plugins/dagrun_listener_async.py     | 449 ++++++++++++++
 .../plugins/dagrun_listener_taskgroup.py | 560 ++++++++++++++++++
 2 files changed, 1009 insertions(+)
 create mode 100644 src/airflow_provider_aiida/plugins/dagrun_listener_async.py
 create mode 100644 src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py

diff --git a/src/airflow_provider_aiida/plugins/dagrun_listener_async.py b/src/airflow_provider_aiida/plugins/dagrun_listener_async.py
new file mode 100644
index 0000000..fdfb6d4
--- /dev/null
+++ b/src/airflow_provider_aiida/plugins/dagrun_listener_async.py
@@ -0,0 +1,449 @@
+import logging
+from airflow.models import DagRun, TaskInstance
+from airflow.plugins_manager import AirflowPlugin
+from airflow.sdk.definitions.param import Param
+from airflow.models import Param as ModelsParam
+from airflow.listeners import hookimpl
+from aiida import load_profile, orm
+from aiida.common.links import LinkType
+from pathlib import Path
+from typing import Any, Dict, Optional
+import json
+
+load_profile()
+
+logger = logging.getLogger(__name__)
+
+
+def _param_to_python(param) -> Any:
+    """
+    Convert an Airflow Param object to a Python native value.
+
+    Args:
+        param: Airflow Param object or any other value
+
+    Returns:
+        Python native value (int, float, bool, str, dict, list, etc.)
+    """
+
+    # Check if it's a Param object
+    if not isinstance(param, (Param, ModelsParam)):
+        return param
+
+    # Get the actual value
+    actual_value = param.value
+
+    # Get schema type if available
+    schema = getattr(param, "schema", {})
+    param_type = schema.get("type", None)
+
+    # Convert based on schema type
+    if param_type == "integer":
+        try:
+            return int(actual_value)
+        except (ValueError, TypeError):
+            logger.warning(f"Could not convert Param value '{actual_value}' to int")
+            return actual_value
+
+    elif param_type == "number":
+        try:
+            return float(actual_value)
+        except (ValueError, TypeError):
+            logger.warning(f"Could not convert Param value '{actual_value}' to float")
+            return actual_value
+
+    elif param_type == "boolean":
+        if isinstance(actual_value, bool):
+            return actual_value
+        # Handle string representations
+        if isinstance(actual_value, str):
+            return actual_value.lower() in ("true", "1", "yes", "on")
+        return bool(actual_value)
+
+    elif param_type == "string":
+        return str(actual_value)
+
+    elif param_type == "object":
+        # Should already be a dict
+        return actual_value if isinstance(actual_value, dict) else {}
+
+    elif param_type == "array":
+        # Should already be a list
+        return actual_value if isinstance(actual_value, (list, tuple)) else []
+
+    else:
+        # No type specified or unknown type - return as-is
+        return actual_value
+
+
+def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]:
+    """
+    Convert a Python value to the appropriate AiiDA Data node.
+
+    Returns None if the value type is not supported or should be skipped.
+ """ + # First check if it's an Airflow Param and convert it + if isinstance(value, (Param, ModelsParam)): + value = _param_to_python(value) + + # Handle basic types (check bool BEFORE int, since bool is subclass of int) + if isinstance(value, bool): + return orm.Bool(value) + elif isinstance(value, int): + return orm.Int(value) + elif isinstance(value, float): + return orm.Float(value) + elif isinstance(value, str): + return orm.Str(value) + + # Handle collections - store as Dict or List nodes + elif isinstance(value, dict): + return orm.Dict(dict=value) + elif isinstance(value, (list, tuple)): + return orm.List(list=list(value)) + + # Handle Path objects + elif isinstance(value, Path): + return orm.Str(str(value)) + + # For complex objects, try JSON serialization + else: + try: + json_str = json.dumps(value) + return orm.Str(json_str) + except (TypeError, ValueError): + logger.warning( + f"Could not convert value of type {type(value)} to AiiDA node" + ) + return None + + +def _store_params_as_aiida_inputs( + node: orm.Node, params: Dict[str, Any], prefix: str = "" +) -> None: + """ + Store parameters as AiiDA data nodes and link them as inputs. + + Args: + node: The AiiDA node to link inputs to + params: Dictionary of parameters to store + prefix: Optional prefix for link labels (e.g., 'dag_params', 'conf') + """ + for key, value in params.items(): + # Create link label with optional prefix + link_label = f"{prefix}_{key}" if prefix else key + + # Skip None values + if value is None: + continue + + # Convert to AiiDA data node + aiida_data = _convert_to_aiida_data(value) + + if aiida_data is not None: + try: + # Store the data node first + aiida_data.store() + # Then add the link + node.base.links.add_incoming( + aiida_data, link_type=LinkType.INPUT_CALC, link_label=link_label + ) + except ValueError as e: + # Link already exists or other constraint violation + logger.debug(f"Could not link {link_label}: {e}") + + +def should_create_calcjob_node(task_instance: TaskInstance) -> bool: + """ + Determine if a task instance should be converted to a CalcJobNode. + + This is where you can implement your logic for identifying which tasks + should be stored in AiiDA. Options: + 1. Check operator type + 2. Check task_id pattern + 3. Check for specific XCom keys + 4. Use task tags/metadata + """ + # Option 1: Check operator type + if "CalcJobTaskOperator" in task_instance.operator: + return True + + # Option 2: Check task_id pattern + if "calcjob" in task_instance.task_id.lower(): + return True + + # Option 3: Check for marker in task metadata + task = task_instance.task + if hasattr(task, "params") and task.params.get("aiida_store", False): + return True + + return False + + +def _store_task_inputs(node: orm.CalcJobNode, task_instance: TaskInstance) -> None: + """ + Store all inputs for a task: params and conf. + + Args: + node: The CalcJobNode to link inputs to + task_instance: The Airflow task instance + """ + # import ipdb; ipdb.set_trace() + # Store task params (static parameters defined in DAG) + if hasattr(task_instance.task, "params") and task_instance.task.params: + _store_params_as_aiida_inputs( + node, task_instance.task.params, prefix="task_param" + ) + + # Store DAG run conf (dynamic parameters from trigger) + if task_instance.dag_run and task_instance.dag_run.conf: + _store_params_as_aiida_inputs(node, task_instance.dag_run.conf, prefix="conf") + + +def _store_task_outputs(node: orm.CalcJobNode, task_instance: TaskInstance) -> None: + """ + Store all outputs from a task: XCom values. 
+ + Args: + node: The CalcJobNode to link outputs to + task_instance: The Airflow task instance + """ + # TODO: continue here, how to get outputs, XComs? + import ipdb + + ipdb.set_trace() + try: + # Get all XCom keys for this task + xcom_data = task_instance.xcom_pull(task_ids=task_instance.task_id, key=None) + + if not xcom_data: + return + + for key, value in xcom_data.items(): + # Skip return_value or handle it separately if needed + if key == "return_value": + continue + + # Convert to AiiDA data node + aiida_data = _convert_to_aiida_data(value) + if aiida_data: + aiida_data.store() + # Link as output (CREATE link) + aiida_data.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label=f"xcom_{key}" + ) + + except Exception as e: + logger.warning( + f"Could not retrieve XCom outputs for task {task_instance.task_id}: {e}" + ) + + +def _create_calcjob_node_from_task( + task_instance: TaskInstance, parent_workchain_node: orm.WorkChainNode +) -> orm.CalcJobNode: + """ + Create an AiiDA CalcJobNode from an Airflow task instance. + """ + node = orm.CalcJobNode() + node.label = f"airflow_calcjob_{task_instance.task_id}" + node.description = f"CalcJob from Airflow task {task_instance.task_id}" + + # Store Airflow metadata in extras + node.base.extras.set("airflow_dag_id", task_instance.dag_id) + node.base.extras.set("airflow_run_id", task_instance.run_id) + node.base.extras.set("airflow_task_id", task_instance.task_id) + + # Set process metadata + node.set_process_type(f"airflow.{task_instance.operator}") + node.set_process_state("finished") + node.set_exit_status(0 if task_instance.state == "success" else 1) + + # Link to parent WorkChainNode (before storing) + if parent_workchain_node: + node.base.links.add_incoming( + parent_workchain_node, + link_type=LinkType.CALL_CALC, + link_label=task_instance.task_id, + ) + + # Add inputs BEFORE storing the node + _store_task_inputs(node, task_instance) + + # Now store the node (inputs are locked in) + node.store() + + # Outputs can be added after storing + # import ipdb; ipdb.set_trace() + _store_task_outputs(node, task_instance) + + logger.info(f"Created CalcJobNode {node.pk} for task {task_instance.task_id}") + return node + + +def _should_integrate_dag_with_aiida(dag_run: DagRun) -> bool: + """Check if this DAG should be stored in AiiDA""" + dag_tags = getattr(dag_run.dag, "tags", []) + return "aiida" in dag_tags + + +def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: + """ + Create a WorkChainNode from a running Airflow DAG and store its inputs. 
+ + Returns: + The created and stored WorkChainNode + """ + workchain_node = orm.WorkChainNode() + workchain_node.label = f"airflow_dag_{dag_run.dag_id}" + workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" + + workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) + workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) + + # Store ALL DAG parameters generically + dag_params = getattr(dag_run.dag, "params", {}) + if dag_params: + _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") + + # Store ALL DAG configuration generically + dag_conf = getattr(dag_run, "conf", {}) + if dag_conf: + _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") + + workchain_node.set_process_state("running") + workchain_node.store() + + logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + return workchain_node + + +def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: + """ + Find the WorkChainNode for a completed DAG run and add outputs (CalcJobNodes). + If the WorkChainNode doesn't exist yet (because on_dag_run_running wasn't called), + create it first. + """ + from aiida.orm import QueryBuilder + + # Try to find the WorkChainNode created in on_dag_run_running + qb = QueryBuilder() + qb.append( + orm.WorkChainNode, + filters={"extras.airflow_run_id": dag_run.run_id}, + tag="workchain", + ) + results = qb.all() + + if not results: + # WorkChainNode doesn't exist yet - create it now with inputs + logger.warning( + f"WorkChainNode not found for run_id {dag_run.run_id}. " + f"Creating it now (on_dag_run_running may not have been called)." + ) + workchain_node = _create_workchain_node_with_inputs(dag_run) + else: + workchain_node = results[0][0] + + # Update process state to finished + workchain_node.set_process_state("finished") + workchain_node.set_exit_status(0) + + # Process each task in the DAG to add CalcJobNodes with outputs + task_instances = dag_run.get_task_instances() + for ti in task_instances: + if ti.state == "success" and should_create_calcjob_node(ti): + _create_calcjob_node_from_task(ti, workchain_node) + + logger.info(f"Finalized WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + + +# def _create_workchain_node_from_dag(dag_run: DagRun) -> None: +# """ +# Create a WorkChainNode from a successful Airflow DAG run. +# +# Now completely generic - stores all params and conf without assumptions. 
+# """ +# # Create the WorkChainNode for the entire DAG +# workchain_node = orm.WorkChainNode() +# workchain_node.label = f"airflow_dag_{dag_run.dag_id}" +# workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" +# +# workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) +# workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) +# +# # Store ALL DAG parameters generically +# dag_params = getattr(dag_run.dag, "params", {}) +# if dag_params: +# _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") +# +# # Store ALL DAG configuration generically +# dag_conf = getattr(dag_run, "conf", {}) +# if dag_conf: +# _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") +# +# workchain_node.set_process_state("finished") +# workchain_node.set_exit_status(0) +# workchain_node.store() +# +# # Process each task in the DAG +# task_instances = dag_run.get_task_instances() +# +# for ti in task_instances: +# # Generic detection: process any successful task +# # You can add filters here if needed (e.g., by operator type) +# if ti.state == "success": +# # Check if this is a CalcJob-like task +# # You might want to add metadata to your tasks to identify them +# if should_create_calcjob_node(ti): +# _create_calcjob_node_from_task(ti, workchain_node) +# +# logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") +# + + +# Airflow Listener Plugin +class AiiDAIntegrationListener: + """Listener that integrates Airflow DAG runs with AiiDA provenance""" + + @hookimpl + def on_dag_run_running(self, dag_run: DagRun, msg: str): + """Called when a DAG run enters the running state.""" + logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") + + if _should_integrate_dag_with_aiida(dag_run): + logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") + try: + _create_workchain_node_with_inputs(dag_run) + except Exception as e: + logger.error( + f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True + ) + + @hookimpl + def on_dag_run_success(self, dag_run: DagRun, msg: str): + """Called when a DAG run completes successfully.""" + logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") + + if _should_integrate_dag_with_aiida(dag_run): + logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") + try: + _finalize_workchain_node_with_outputs(dag_run) + except Exception as e: + logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) + + @hookimpl + def on_dag_run_failed(self, dag_run: DagRun, msg: str): + """Called when a DAG run fails.""" + logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") + # Optionally store failed runs in AiiDA with appropriate exit status + + +# Create listener instance +aiida_listener = AiiDAIntegrationListener() + + +# Plugin registration +class AiiDAIntegrationPlugin(AirflowPlugin): + name = "aiida_integration_plugin" + listeners = [aiida_listener] diff --git a/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py b/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py new file mode 100644 index 0000000..23e5b5b --- /dev/null +++ b/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py @@ -0,0 +1,560 @@ +import logging +from airflow.models import DagRun, TaskInstance +from airflow.plugins_manager import AirflowPlugin +from airflow.sdk.definitions.param import Param +from airflow.models import Param as ModelsParam +from airflow.listeners import hookimpl +from aiida import load_profile, orm +from 
aiida.common.links import LinkType +from pathlib import Path +from typing import Any, Dict, Optional +import json +import sys +sys.path.append('/home/geiger_j/aiida_projects/aiida-airflow/git-repos/airflow-prototype/dags/') +from calcjob_inheritance import CalcJobTaskGroup + +load_profile() + +logger = logging.getLogger(__name__) + + +def _param_to_python(param) -> Any: + """ + Convert an Airflow Param object to a Python native value. + + Args: + param: Airflow Param object or any other value + + Returns: + Python native value (int, float, bool, str, dict, list, etc.) + """ + + # Check if it's a Param object + if not isinstance(param, (Param, ModelsParam)): + return param + + # Get the actual value + actual_value = param.value + + # Get schema type if available + schema = getattr(param, "schema", {}) + param_type = schema.get("type", None) + + # Convert based on schema type + if param_type == "integer": + try: + return int(actual_value) + except (ValueError, TypeError): + logger.warning(f"Could not convert Param value '{actual_value}' to int") + return actual_value + + elif param_type == "number": + try: + return float(actual_value) + except (ValueError, TypeError): + logger.warning(f"Could not convert Param value '{actual_value}' to float") + return actual_value + + elif param_type == "boolean": + if isinstance(actual_value, bool): + return actual_value + # Handle string representations + if isinstance(actual_value, str): + return actual_value.lower() in ("true", "1", "yes", "on") + return bool(actual_value) + + elif param_type == "string": + return str(actual_value) + + elif param_type == "object": + # Should already be a dict + return actual_value if isinstance(actual_value, dict) else {} + + elif param_type == "array": + # Should already be a list + return actual_value if isinstance(actual_value, (list, tuple)) else [] + + else: + # No type specified or unknown type - return as-is + return actual_value + + +def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]: + """ + Convert a Python value to the appropriate AiiDA Data node. + + Returns None if the value type is not supported or should be skipped. + """ + # First check if it's an Airflow Param and convert it + if isinstance(value, (Param, ModelsParam)): + value = _param_to_python(value) + + # Handle basic types (check bool BEFORE int, since bool is subclass of int) + if isinstance(value, bool): + return orm.Bool(value) + elif isinstance(value, int): + return orm.Int(value) + elif isinstance(value, float): + return orm.Float(value) + elif isinstance(value, str): + return orm.Str(value) + + # Handle collections - store as Dict or List nodes + elif isinstance(value, dict): + return orm.Dict(dict=value) + elif isinstance(value, (list, tuple)): + return orm.List(list=list(value)) + + # Handle Path objects + elif isinstance(value, Path): + return orm.Str(str(value)) + + # For complex objects, try JSON serialization + else: + try: + json_str = json.dumps(value) + return orm.Str(json_str) + except (TypeError, ValueError): + logger.warning( + f"Could not convert value of type {type(value)} to AiiDA node" + ) + return None + + +def _store_params_as_aiida_inputs( + node: orm.Node, params: Dict[str, Any], prefix: str = "" +) -> None: + """ + Store parameters as AiiDA data nodes and link them as inputs. 
+ + Args: + node: The AiiDA node to link inputs to + params: Dictionary of parameters to store + prefix: Optional prefix for link labels (e.g., 'dag_params', 'conf') + """ + for key, value in params.items(): + # Create link label with optional prefix + link_label = f"{prefix}_{key}" if prefix else key + + # Skip None values + if value is None: + continue + + # Convert to AiiDA data node + aiida_data = _convert_to_aiida_data(value) + if isinstance(node, orm.WorkflowNode): + link_type = LinkType.INPUT_WORK + elif isinstance(node, orm.CalculationNode): + link_type = LinkType.INPUT_CALC + + if aiida_data is not None: + try: + # Store the data node first + aiida_data.store() + # Then add the link + node.base.links.add_incoming( + aiida_data, link_type=link_type, link_label=link_label + ) + except ValueError as e: + # Link already exists or other constraint violation + logger.debug(f"Could not link {link_label}: {e}") + + +def should_create_calcjob_node_for_taskgroup(task_instance: TaskInstance) -> bool: + """ + Determine if a task instance is part of a CalcJobTaskGroup. + + This checks if the task is the "parse" task of a CalcJobTaskGroup, + which signals completion of the entire group. + + Args: + task_instance: Airflow task instance + + Returns: + bool: True if this is a parse task from a CalcJobTaskGroup + """ + # Check if task_id indicates it's a parse task in a task group + if ".parse" in task_instance.task_id: + # Verify parent group exists and has the expected structure + group_id = task_instance.task_id.rsplit(".parse", 1)[0] + + # Check if this is likely a CalcJobTaskGroup by looking for sibling tasks + dag_run = task_instance.dag_run + if dag_run: + task_instances = dag_run.get_task_instances() + # Look for the prepare task in the same group + for ti in task_instances: + if ti.task_id == f"{group_id}.prepare": + return True + + return False + + +def _get_taskgroup_id_from_parse_task(task_instance: TaskInstance) -> str: + """Extract the task group ID from a parse task's task_id""" + return task_instance.task_id.rsplit(".parse", 1)[0] + + +def _store_taskgroup_inputs( + node: orm.CalcJobNode, task_instance: TaskInstance, dag_run: DagRun +) -> None: + """ + Store all inputs for a CalcJobTaskGroup. + + Inputs come from: + 1. The prepare task's XCom outputs (to_upload_files, submission_script, to_receive_files) + 2. The CalcJobTaskGroup instance's parameters (x, y, sleep, etc.) + 3. 
DAG-level params and conf + + Args: + node: The CalcJobNode to link inputs to + task_instance: The parse task instance (end of the group) + dag_run: The DAG run containing the task + """ + group_id = _get_taskgroup_id_from_parse_task(task_instance) + prepare_task_id = f"{group_id}.prepare" + + # Get the prepare task instance to access its XCom data + prepare_ti = None + for ti in dag_run.get_task_instances(): + if ti.task_id == prepare_task_id: + prepare_ti = ti + break + + if not prepare_ti: + logger.warning(f"Could not find prepare task {prepare_task_id}") + return + + # Store prepare task outputs as inputs to the CalcJob + try: + to_upload_files = task_instance.xcom_pull( + task_ids=prepare_task_id, key="to_upload_files" + ) + if to_upload_files: + aiida_data = _convert_to_aiida_data(to_upload_files) + if aiida_data: + aiida_data.store() + node.base.links.add_incoming( + aiida_data, + link_type=LinkType.INPUT_CALC, + link_label="to_upload_files", + ) + except Exception as e: + logger.debug(f"Could not store to_upload_files: {e}") + + try: + submission_script = task_instance.xcom_pull( + task_ids=prepare_task_id, key="submission_script" + ) + if submission_script: + aiida_data = _convert_to_aiida_data(submission_script) + if aiida_data: + aiida_data.store() + node.base.links.add_incoming( + aiida_data, + link_type=LinkType.INPUT_CALC, + link_label="submission_script", + ) + except Exception as e: + logger.debug(f"Could not store submission_script: {e}") + + try: + to_receive_files = task_instance.xcom_pull( + task_ids=prepare_task_id, key="to_receive_files" + ) + if to_receive_files: + aiida_data = _convert_to_aiida_data(to_receive_files) + if aiida_data: + aiida_data.store() + node.base.links.add_incoming( + aiida_data, + link_type=LinkType.INPUT_CALC, + link_label="to_receive_files", + ) + except Exception as e: + logger.debug(f"Could not store to_receive_files: {e}") + + # Store DAG-level params and conf + if dag_run.conf: + _store_params_as_aiida_inputs(node, dag_run.conf, prefix="conf") + + dag_params = getattr(dag_run.dag, "params", {}) + if dag_params: + _store_params_as_aiida_inputs(node, dag_params, prefix="dag_param") + + +def _store_taskgroup_outputs( + node: orm.CalcJobNode, task_instance: TaskInstance +) -> None: + """ + Store all outputs from a CalcJobTaskGroup. + + Outputs come from the parse task's XCom data (final_result). + + Args: + node: The CalcJobNode to link outputs to + task_instance: The parse task instance + """ + # NOTE: Aren't all the constructs we create taskgroups? 
How to differentiate between DAG and calcjob + try: + # Get the final_result from the parse task + final_result = task_instance.xcom_pull( + task_ids=task_instance.task_id, key="final_result" + ) + + if final_result: + # Handle tuple format (exit_status, results) from AddJobTaskGroup + if isinstance(final_result, tuple) and len(final_result) == 2: + exit_status, results = final_result + + # Store exit status + exit_status_node = orm.Int(exit_status) + exit_status_node.store() + exit_status_node.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label="exit_status" + ) + + # Store results dict + if results: + for key, value in results.items(): + aiida_data = _convert_to_aiida_data(value) + if aiida_data: + aiida_data.store() + aiida_data.base.links.add_incoming( + node, + link_type=LinkType.CREATE, + link_label=f"result_{key}", + ) + + # Handle dict format from MultiplyJobTaskGroup + elif isinstance(final_result, dict): + for key, value in final_result.items(): + aiida_data = _convert_to_aiida_data(value) + if aiida_data: + aiida_data.store() + aiida_data.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label=f"result_{key}" + ) + + # Handle other formats + else: + aiida_data = _convert_to_aiida_data(final_result) + if aiida_data: + aiida_data.store() + aiida_data.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label="final_result" + ) + + except Exception as e: + logger.warning( + f"Could not retrieve outputs for task {task_instance.task_id}: {e}" + ) + + +def _create_calcjob_node_from_taskgroup( + task_instance: TaskInstance, + parent_workchain_node: orm.WorkChainNode, + dag_run: DagRun, +) -> orm.CalcJobNode: + """ + Create an AiiDA CalcJobNode from a CalcJobTaskGroup (represented by its parse task). + + Args: + task_instance: The parse task instance (end of the TaskGroup) + parent_workchain_node: The parent WorkChainNode for the DAG + dag_run: The DAG run + + Returns: + The created and stored CalcJobNode + """ + # import ipdb; ipdb.set_trace() + # NOTE: locals() + # {'dag_run': , + # 'ipdb': , + # 'parent_workchain_node': , + # 'task_instance': } + + # NOTE: Should one pass a `task_instance` here, or shouldn't it be a taskgroup. + # possibly apply this function to every task group, and have special handling only when it is calcjob. 
if not, then store things in the "workchain way" + + # if isinstance(CalcJobTaskGroup): + + group_id = _get_taskgroup_id_from_parse_task(task_instance) + + node = orm.CalcJobNode() + node.label = f"airflow_calcjob_group_{group_id}" + node.description = f"CalcJob from Airflow TaskGroup {group_id}" + + # Store Airflow metadata in extras + node.base.extras.set("airflow_dag_id", task_instance.dag_id) + node.base.extras.set("airflow_run_id", task_instance.run_id) + node.base.extras.set("airflow_task_group_id", group_id) + + # Set process metadata + node.set_process_type(f"airflow.CalcJobTaskGroup") + node.set_process_state("finished") + + # Determine exit status from parse task result + exit_status = 0 + try: + final_result = task_instance.xcom_pull( + task_ids=task_instance.task_id, key="final_result" + ) + if isinstance(final_result, tuple) and len(final_result) == 2: + exit_status = final_result[0] + except Exception: + pass + + node.set_exit_status(exit_status if task_instance.state == "success" else 1) + + # Link to parent WorkChainNode (before storing) + if parent_workchain_node: + node.base.links.add_incoming( + parent_workchain_node, + link_type=LinkType.CALL_CALC, + link_label=group_id, + ) + + # Add inputs BEFORE storing the node + # TODO: Computer is not an input, but an attribute of the calcjob, or, rather, `metadata.computer` + _store_taskgroup_inputs(node, task_instance, dag_run) + + # Now store the node (inputs are locked in) + node.store() + # NOTE: final_result -> {'__classname__': 'builtins.tuple', '__version__': 1, '__data__': [0, {'result.out': 12}]} + # TODO: + + # Outputs can be added after storing + _store_taskgroup_outputs(node, task_instance) + + logger.info(f"Created CalcJobNode {node.pk} for TaskGroup {group_id}") + return node + + +def _should_integrate_dag_with_aiida(dag_run: DagRun) -> bool: + """Check if this DAG should be stored in AiiDA""" + dag_tags = getattr(dag_run.dag, "tags", []) + # Look for tags that indicate this is a CalcJob workflow + return any(tag in dag_tags for tag in ["aiida", "calcjob", "taskgroup"]) + + +def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: + """ + Create a WorkChainNode from a running Airflow DAG and store its inputs. 
+ + Returns: + The created and stored WorkChainNode + """ + workchain_node = orm.WorkChainNode() + workchain_node.label = f"airflow_dag_{dag_run.dag_id}" + workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" + + workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) + workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) + + # Store ALL DAG parameters generically + dag_params = getattr(dag_run.dag, "params", {}) + if dag_params: + # TODO: These are being set as inputs of the workflow, even though they are inputs of the calculations, and are just passed through the dag + # {'machine': 'localhost', 'local_workdir': '/home/geiger_j/airflow/storage/local_workdir', 'remote_workdir': '/home/geiger_j/airflow/storage/remote_workdir'} + # import ipdb; ipdb.set_trace() + _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") + + # Store ALL DAG configuration generically + dag_conf = getattr(dag_run, "conf", {}) + if dag_conf: + _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") + + workchain_node.set_process_state("running") + workchain_node.store() + + logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + return workchain_node + + +def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: + """ + Find the WorkChainNode for a completed DAG run and add outputs (CalcJobNodes from TaskGroups). + If the WorkChainNode doesn't exist yet, create it first. + """ + # NOTE: Why do i need to query for it? it should be directly accessible, no? -> it's bc i create it in the other function? + from aiida.orm import QueryBuilder + + # Try to find the WorkChainNode created in on_dag_run_running + qb = QueryBuilder() + qb.append( + orm.WorkChainNode, + filters={"extras.airflow_run_id": dag_run.run_id}, + tag="workchain", + ) + results = qb.all() + + if not results: + # WorkChainNode doesn't exist yet - create it now with inputs + logger.warning( + f"WorkChainNode not found for run_id {dag_run.run_id}. " + f"Creating it now (on_dag_run_running may not have been called)." 
+ ) + workchain_node = _create_workchain_node_with_inputs(dag_run) + else: + workchain_node = results[0][0] + + # Update process state to finished + workchain_node.set_process_state("finished") + workchain_node.set_exit_status(0) + + # Process each task in the DAG to find CalcJobTaskGroup parse tasks + task_instances = dag_run.get_task_instances() + for ti in task_instances: + if ti.state == "success" and should_create_calcjob_node_for_taskgroup(ti): + _create_calcjob_node_from_taskgroup(ti, workchain_node, dag_run) + + logger.info(f"Finalized WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + + +# Airflow Listener Plugin +class AiiDATaskGroupIntegrationListener: + """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" + + @hookimpl + def on_dag_run_running(self, dag_run: DagRun, msg: str): + """Called when a DAG run enters the running state.""" + logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") + + if _should_integrate_dag_with_aiida(dag_run): + logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") + try: + _create_workchain_node_with_inputs(dag_run) + except Exception as e: + logger.error( + f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True + ) + + @hookimpl + def on_dag_run_success(self, dag_run: DagRun, msg: str): + """Called when a DAG run completes successfully.""" + logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") + + if _should_integrate_dag_with_aiida(dag_run): + logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") + try: + _finalize_workchain_node_with_outputs(dag_run) + except Exception as e: + logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) + + @hookimpl + def on_dag_run_failed(self, dag_run: DagRun, msg: str): + """Called when a DAG run fails.""" + logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") + # Optionally store failed runs in AiiDA with appropriate exit status + + +# Create listener instance +aiida_taskgroup_listener = AiiDATaskGroupIntegrationListener() + + +# Plugin registration +class AiiDATaskGroupIntegrationPlugin(AirflowPlugin): + name = "aiida_taskgroup_integration_plugin" + listeners = [aiida_taskgroup_listener] + From 4a82878094b4ee9f84f23669d122fd0dd4d3988b Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 2 Oct 2025 17:00:08 +0200 Subject: [PATCH 02/11] backup for prev listener files --- .../{dagrun_listener_async.py => dagrun_listener_async.py.bak} | 0 ...run_listener_taskgroup.py => dagrun_listener_taskgroup.py.bak} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/airflow_provider_aiida/plugins/{dagrun_listener_async.py => dagrun_listener_async.py.bak} (100%) rename src/airflow_provider_aiida/plugins/{dagrun_listener_taskgroup.py => dagrun_listener_taskgroup.py.bak} (100%) diff --git a/src/airflow_provider_aiida/plugins/dagrun_listener_async.py b/src/airflow_provider_aiida/plugins/dagrun_listener_async.py.bak similarity index 100% rename from src/airflow_provider_aiida/plugins/dagrun_listener_async.py rename to src/airflow_provider_aiida/plugins/dagrun_listener_async.py.bak diff --git a/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py b/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py.bak similarity index 100% rename from src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py rename to src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py.bak From 1889b02fcae6ec7543af9c2b85b202f37013e473 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 
2 Oct 2025 17:04:02 +0200 Subject: [PATCH 03/11] integrate into aiida_dag_run_listener.py --- .../plugins/aiida_dag_run_listener.py | 643 ++++++++++++++---- .../plugins/dagrun_listener_async.py.bak | 449 ------------ .../plugins/dagrun_listener_taskgroup.py.bak | 560 --------------- 3 files changed, 498 insertions(+), 1154 deletions(-) delete mode 100644 src/airflow_provider_aiida/plugins/dagrun_listener_async.py.bak delete mode 100644 src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py.bak diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index 8ce15e3..f68caaf 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -1,189 +1,542 @@ -import sqlite3 -import os import logging -from datetime import datetime +import sys +from pathlib import Path +from typing import Any, Dict, Optional + +from airflow.models import DagRun, TaskInstance from airflow.plugins_manager import AirflowPlugin +from airflow.sdk.definitions.param import Param +from airflow.models import Param as ModelsParam from airflow.listeners import hookimpl -from airflow.models import DagRun, XCom -from airflow.utils.state import DagRunState +from aiida import load_profile, orm +from aiida.common.links import LinkType +import json + +# Add dags directory to path for CalcJobTaskGroup import +sys.path.append('/home/geiger_j/aiida_projects/aiida-airflow/git-repos/airflow-prototype/dags/') +from calcjob_inheritance import CalcJobTaskGroup + +load_profile() logger = logging.getLogger(__name__) -# Database path -DB_PATH = os.path.join(os.path.dirname(__file__), 'dagrun_tracking.db') -def _init_database(): - """Initialize the SQLite database and create the dagrun table if it doesn't exist.""" +def _param_to_python(param) -> Any: + """ + Convert an Airflow Param object to a Python native value. + + Args: + param: Airflow Param object or any other value + + Returns: + Python native value (int, float, bool, str, dict, list, etc.) 
+ """ + + # Check if it's a Param object + if not isinstance(param, (Param, ModelsParam)): + return param + + # Get the actual value + actual_value = param.value + + # Get schema type if available + schema = getattr(param, "schema", {}) + param_type = schema.get("type", None) + + # Convert based on schema type + if param_type == "integer": + try: + return int(actual_value) + except (ValueError, TypeError): + logger.warning(f"Could not convert Param value '{actual_value}' to int") + return actual_value + + elif param_type == "number": + try: + return float(actual_value) + except (ValueError, TypeError): + logger.warning(f"Could not convert Param value '{actual_value}' to float") + return actual_value + + elif param_type == "boolean": + if isinstance(actual_value, bool): + return actual_value + # Handle string representations + if isinstance(actual_value, str): + return actual_value.lower() in ("true", "1", "yes", "on") + return bool(actual_value) + + elif param_type == "string": + return str(actual_value) + + elif param_type == "object": + # Should already be a dict + return actual_value if isinstance(actual_value, dict) else {} + + elif param_type == "array": + # Should already be a list + return actual_value if isinstance(actual_value, (list, tuple)) else [] + + else: + # No type specified or unknown type - return as-is + return actual_value + + +def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]: + """ + Convert a Python value to the appropriate AiiDA Data node. + + Returns None if the value type is not supported or should be skipped. + """ + # First check if it's an Airflow Param and convert it + if isinstance(value, (Param, ModelsParam)): + value = _param_to_python(value) + + # Handle basic types (check bool BEFORE int, since bool is subclass of int) + if isinstance(value, bool): + return orm.Bool(value) + elif isinstance(value, int): + return orm.Int(value) + elif isinstance(value, float): + return orm.Float(value) + elif isinstance(value, str): + return orm.Str(value) + + # Handle collections - store as Dict or List nodes + elif isinstance(value, dict): + return orm.Dict(dict=value) + elif isinstance(value, (list, tuple)): + return orm.List(list=list(value)) + + # Handle Path objects + elif isinstance(value, Path): + return orm.Str(str(value)) + + # For complex objects, try JSON serialization + else: + try: + json_str = json.dumps(value) + return orm.Str(json_str) + except (TypeError, ValueError): + logger.warning( + f"Could not convert value of type {type(value)} to AiiDA node" + ) + return None + + +def _store_params_as_aiida_inputs( + node: orm.Node, params: Dict[str, Any], prefix: str = "" +) -> None: + """ + Store parameters as AiiDA data nodes and link them as inputs. 
+ + Args: + node: The AiiDA node to link inputs to + params: Dictionary of parameters to store + prefix: Optional prefix for link labels (e.g., 'dag_params', 'conf') + """ + for key, value in params.items(): + # Create link label with optional prefix + link_label = f"{prefix}_{key}" if prefix else key + + # Skip None values + if value is None: + continue + + # Convert to AiiDA data node + aiida_data = _convert_to_aiida_data(value) + if isinstance(node, orm.WorkflowNode): + link_type = LinkType.INPUT_WORK + elif isinstance(node, orm.CalculationNode): + link_type = LinkType.INPUT_CALC + + if aiida_data is not None: + try: + # Store the data node first + aiida_data.store() + # Then add the link + node.base.links.add_incoming( + aiida_data, link_type=link_type, link_label=link_label + ) + except ValueError as e: + # Link already exists or other constraint violation + logger.debug(f"Could not link {link_label}: {e}") + + +def should_create_calcjob_node_for_taskgroup(task_instance: TaskInstance) -> bool: + """ + Determine if a task instance is part of a CalcJobTaskGroup. + + This checks if the task is the "parse" task of a CalcJobTaskGroup, + which signals completion of the entire group. + + Args: + task_instance: Airflow task instance + + Returns: + bool: True if this is a parse task from a CalcJobTaskGroup + """ + # Check if task_id indicates it's a parse task in a task group + if ".parse" in task_instance.task_id: + # Verify parent group exists and has the expected structure + group_id = task_instance.task_id.rsplit(".parse", 1)[0] + + # Check if this is likely a CalcJobTaskGroup by looking for sibling tasks + dag_run = task_instance.dag_run + if dag_run: + task_instances = dag_run.get_task_instances() + # Look for the prepare task in the same group + for ti in task_instances: + if ti.task_id == f"{group_id}.prepare": + return True + + return False + + +def _get_taskgroup_id_from_parse_task(task_instance: TaskInstance) -> str: + """Extract the task group ID from a parse task's task_id""" + return task_instance.task_id.rsplit(".parse", 1)[0] + + +def _store_taskgroup_inputs( + node: orm.CalcJobNode, task_instance: TaskInstance, dag_run: DagRun +) -> None: + """ + Store all inputs for a CalcJobTaskGroup. + + Inputs come from: + 1. The prepare task's XCom outputs (to_upload_files, submission_script, to_receive_files) + 2. The CalcJobTaskGroup instance's parameters (x, y, sleep, etc.) + 3. 
DAG-level params and conf + + Args: + node: The CalcJobNode to link inputs to + task_instance: The parse task instance (end of the group) + dag_run: The DAG run containing the task + """ + group_id = _get_taskgroup_id_from_parse_task(task_instance) + prepare_task_id = f"{group_id}.prepare" + + # Get the prepare task instance to access its XCom data + prepare_ti = None + for ti in dag_run.get_task_instances(): + if ti.task_id == prepare_task_id: + prepare_ti = ti + break + + if not prepare_ti: + logger.warning(f"Could not find prepare task {prepare_task_id}") + return + + # Store prepare task outputs as inputs to the CalcJob try: - with sqlite3.connect(DB_PATH) as conn: - cursor = conn.cursor() - cursor.execute(''' - CREATE TABLE IF NOT EXISTS dagrun_events ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - dag_id TEXT NOT NULL, - run_id TEXT NOT NULL, - run_type TEXT, - state TEXT, - execution_date TEXT, - start_date TEXT, - end_date TEXT, - external_trigger BOOLEAN, - conf TEXT, - dag_output TEXT, - event_type TEXT NOT NULL, - event_timestamp TEXT NOT NULL, - created_at TEXT DEFAULT CURRENT_TIMESTAMP + to_upload_files = task_instance.xcom_pull( + task_ids=prepare_task_id, key="to_upload_files" + ) + if to_upload_files: + aiida_data = _convert_to_aiida_data(to_upload_files) + if aiida_data: + aiida_data.store() + node.base.links.add_incoming( + aiida_data, + link_type=LinkType.INPUT_CALC, + link_label="to_upload_files", ) - ''') - conn.commit() - logger.info(f"DagRun tracking database initialized at {DB_PATH}") except Exception as e: - logger.error(f"Failed to initialize database: {e}") + logger.debug(f"Could not store to_upload_files: {e}") -def _get_dag_output(dagrun: DagRun) -> str: - """Retrieve DAG output from XCom if available.""" try: - # Look for XCom with key 'dag_output' from any task in this DAG run - from airflow.models import XCom - from airflow.utils.session import provide_session - - @provide_session - def _query_xcom(session=None): - xcom_value = session.query(XCom).filter( - XCom.dag_id == dagrun.dag_id, - XCom.run_id == dagrun.run_id, - XCom.key == 'dag_output' - ).first() - return xcom_value.value if xcom_value else None - - output = _query_xcom() - logger.info(f"[DEBUG] Retrieved DAG output: {output}") - return str(output) if output else '{}' + submission_script = task_instance.xcom_pull( + task_ids=prepare_task_id, key="submission_script" + ) + if submission_script: + aiida_data = _convert_to_aiida_data(submission_script) + if aiida_data: + aiida_data.store() + node.base.links.add_incoming( + aiida_data, + link_type=LinkType.INPUT_CALC, + link_label="submission_script", + ) + except Exception as e: + logger.debug(f"Could not store submission_script: {e}") + try: + to_receive_files = task_instance.xcom_pull( + task_ids=prepare_task_id, key="to_receive_files" + ) + if to_receive_files: + aiida_data = _convert_to_aiida_data(to_receive_files) + if aiida_data: + aiida_data.store() + node.base.links.add_incoming( + aiida_data, + link_type=LinkType.INPUT_CALC, + link_label="to_receive_files", + ) except Exception as e: - logger.warning(f"Failed to retrieve DAG output: {e}") - return '{}' + logger.debug(f"Could not store to_receive_files: {e}") + + # Store DAG-level params and conf + if dag_run.conf: + _store_params_as_aiida_inputs(node, dag_run.conf, prefix="conf") + + dag_params = getattr(dag_run.dag, "params", {}) + if dag_params: + _store_params_as_aiida_inputs(node, dag_params, prefix="dag_param") + -def _get_dag_output_safe(dag_id: str, run_id: str) -> str: - """Safely retrieve 
DAG output from XCom using dag_id and run_id strings.""" +def _store_taskgroup_outputs( + node: orm.CalcJobNode, task_instance: TaskInstance +) -> None: + """ + Store all outputs from a CalcJobTaskGroup. + + Outputs come from the parse task's XCom data (final_result). + + Args: + node: The CalcJobNode to link outputs to + task_instance: The parse task instance + """ try: - # Look for XCom with key 'dag_output' from any task in this DAG run - from airflow.models import XCom - from airflow.utils.session import provide_session - - @provide_session - def _query_xcom(session=None): - xcom_value = session.query(XCom).filter( - XCom.dag_id == dag_id, - XCom.run_id == run_id, - XCom.key == 'dag_output' - ).first() - return xcom_value.value if xcom_value else None - - output = _query_xcom() - logger.info(f"[DEBUG] Retrieved DAG output: {output}") - return str(output) if output else '{}' + # Get the final_result from the parse task + final_result = task_instance.xcom_pull( + task_ids=task_instance.task_id, key="final_result" + ) + + if final_result: + # Handle tuple format (exit_status, results) from AddJobTaskGroup + if isinstance(final_result, tuple) and len(final_result) == 2: + exit_status, results = final_result + + # Store exit status + exit_status_node = orm.Int(exit_status) + exit_status_node.store() + exit_status_node.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label="exit_status" + ) + + # Store results dict + if results: + for key, value in results.items(): + aiida_data = _convert_to_aiida_data(value) + if aiida_data: + aiida_data.store() + aiida_data.base.links.add_incoming( + node, + link_type=LinkType.CREATE, + link_label=f"result_{key}", + ) + + # Handle dict format from MultiplyJobTaskGroup + elif isinstance(final_result, dict): + for key, value in final_result.items(): + aiida_data = _convert_to_aiida_data(value) + if aiida_data: + aiida_data.store() + aiida_data.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label=f"result_{key}" + ) + + # Handle other formats + else: + aiida_data = _convert_to_aiida_data(final_result) + if aiida_data: + aiida_data.store() + aiida_data.base.links.add_incoming( + node, link_type=LinkType.CREATE, link_label="final_result" + ) except Exception as e: - logger.warning(f"Failed to retrieve DAG output: {e}") - return '{}' + logger.warning( + f"Could not retrieve outputs for task {task_instance.task_id}: {e}" + ) + + +def _create_calcjob_node_from_taskgroup( + task_instance: TaskInstance, + parent_workchain_node: orm.WorkChainNode, + dag_run: DagRun, +) -> orm.CalcJobNode: + """ + Create an AiiDA CalcJobNode from a CalcJobTaskGroup (represented by its parse task). 
+ + Args: + task_instance: The parse task instance (end of the TaskGroup) + parent_workchain_node: The parent WorkChainNode for the DAG + dag_run: The DAG run + + Returns: + The created and stored CalcJobNode + """ + group_id = _get_taskgroup_id_from_parse_task(task_instance) -def _store_dagrun_event(dagrun: DagRun, event_type: str): - """Store dagrun event information to SQLite database.""" + node = orm.CalcJobNode() + node.label = f"airflow_calcjob_group_{group_id}" + node.description = f"CalcJob from Airflow TaskGroup {group_id}" + + # Store Airflow metadata in extras + node.base.extras.set("airflow_dag_id", task_instance.dag_id) + node.base.extras.set("airflow_run_id", task_instance.run_id) + node.base.extras.set("airflow_task_group_id", group_id) + + # Set process metadata + node.set_process_type(f"airflow.CalcJobTaskGroup") + node.set_process_state("finished") + + # Determine exit status from parse task result + exit_status = 0 try: - # IMPORTANT: Extract all needed attributes FIRST to avoid DetachedInstanceError - # This must happen before any database operations - dag_id = dagrun.dag_id - run_id = dagrun.run_id - run_type_str = dagrun.run_type.value if hasattr(dagrun.run_type, 'value') else str(dagrun.run_type) if dagrun.run_type else None - state_str = dagrun.state.value if hasattr(dagrun.state, 'value') else str(dagrun.state) if dagrun.state else None - execution_date = getattr(dagrun, 'execution_date', None) - start_date = getattr(dagrun, 'start_date', None) - end_date = getattr(dagrun, 'end_date', None) - external_trigger = getattr(dagrun, 'external_trigger', False) - conf = getattr(dagrun, 'conf', {}) - - logger.info(f"[DEBUG] Starting to store {event_type} event") - logger.info(f"[DEBUG] DB_PATH: {DB_PATH}") - logger.info(f"[DEBUG] dag_id: {dag_id}") - logger.info(f"[DEBUG] run_id: {run_id}") - - with sqlite3.connect(DB_PATH) as conn: - cursor = conn.cursor() - - # Test if table exists - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='dagrun_events'") - table_exists = cursor.fetchone() - logger.info(f"[DEBUG] Table exists: {table_exists is not None}") - - # Get DAG output for completed DAGs (pass extracted values instead of dagrun object) - dag_output = _get_dag_output_safe(dag_id, run_id) if event_type in ['success', 'failed'] else '{}' - - data_tuple = ( - dag_id, - run_id, - run_type_str, - state_str, - execution_date.isoformat() if execution_date else None, - start_date.isoformat() if start_date else None, - end_date.isoformat() if end_date else None, - external_trigger, - str(conf) if conf else '{}', - dag_output, - event_type, - datetime.now().isoformat() - ) + final_result = task_instance.xcom_pull( + task_ids=task_instance.task_id, key="final_result" + ) + if isinstance(final_result, tuple) and len(final_result) == 2: + exit_status = final_result[0] + except Exception: + pass - logger.info(f"[DEBUG] Data tuple: {data_tuple}") + node.set_exit_status(exit_status if task_instance.state == "success" else 1) - cursor.execute(''' - INSERT INTO dagrun_events ( - dag_id, run_id, run_type, state, execution_date, - start_date, end_date, external_trigger, conf, dag_output, - event_type, event_timestamp, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ''', data_tuple + (datetime.now().strftime('%Y-%m-%d %H:%M:%S'),)) + # Link to parent WorkChainNode (before storing) + if parent_workchain_node: + node.base.links.add_incoming( + parent_workchain_node, + link_type=LinkType.CALL_CALC, + link_label=group_id, + ) - conn.commit() - logger.info(f"[SUCCESS] Stored {event_type} event for DAG run {dag_id}/{run_id}") + # Add inputs BEFORE storing the node + _store_taskgroup_inputs(node, task_instance, dag_run) - # Verify insertion - cursor.execute("SELECT COUNT(*) FROM dagrun_events") - count = cursor.fetchone()[0] - logger.info(f"[DEBUG] Total events in DB: {count}") + # Now store the node (inputs are locked in) + node.store() - except Exception as e: - logger.error(f"[ERROR] Failed to store dagrun event: {e}") - import traceback - logger.error(f"[ERROR] Traceback: {traceback.format_exc()}") + # Outputs can be added after storing + _store_taskgroup_outputs(node, task_instance) -# Initialize database on import -_init_database() + logger.info(f"Created CalcJobNode {node.pk} for TaskGroup {group_id}") + return node -class DagRunListener: - """Class-based DAG run listener.""" + +def _should_integrate_dag_with_aiida(dag_run: DagRun) -> bool: + """Check if this DAG should be stored in AiiDA""" + dag_tags = getattr(dag_run.dag, "tags", []) + # Look for tags that indicate this is a CalcJob workflow + return any(tag in dag_tags for tag in ["aiida", "calcjob", "taskgroup"]) + + +def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: + """ + Create a WorkChainNode from a running Airflow DAG and store its inputs. + + Returns: + The created and stored WorkChainNode + """ + workchain_node = orm.WorkChainNode() + workchain_node.label = f"airflow_dag_{dag_run.dag_id}" + workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" + + workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) + workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) + + # Store ALL DAG parameters generically + dag_params = getattr(dag_run.dag, "params", {}) + if dag_params: + _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") + + # Store ALL DAG configuration generically + dag_conf = getattr(dag_run, "conf", {}) + if dag_conf: + _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") + + workchain_node.set_process_state("running") + workchain_node.store() + + logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + return workchain_node + + +def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: + """ + Find the WorkChainNode for a completed DAG run and add outputs (CalcJobNodes from TaskGroups). + If the WorkChainNode doesn't exist yet, create it first. + """ + from aiida.orm import QueryBuilder + + # Try to find the WorkChainNode created in on_dag_run_running + qb = QueryBuilder() + qb.append( + orm.WorkChainNode, + filters={"extras.airflow_run_id": dag_run.run_id}, + tag="workchain", + ) + results = qb.all() + + if not results: + # WorkChainNode doesn't exist yet - create it now with inputs + logger.warning( + f"WorkChainNode not found for run_id {dag_run.run_id}. " + f"Creating it now (on_dag_run_running may not have been called)." 
+ ) + workchain_node = _create_workchain_node_with_inputs(dag_run) + else: + workchain_node = results[0][0] + + # Update process state to finished + workchain_node.set_process_state("finished") + workchain_node.set_exit_status(0) + + # Process each task in the DAG to find CalcJobTaskGroup parse tasks + task_instances = dag_run.get_task_instances() + for ti in task_instances: + if ti.state == "success" and should_create_calcjob_node_for_taskgroup(ti): + _create_calcjob_node_from_taskgroup(ti, workchain_node, dag_run) + + logger.info(f"Finalized WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + + +# Airflow Listener Plugin +class AiiDATaskGroupIntegrationListener: + """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" @hookimpl def on_dag_run_running(self, dag_run: DagRun, msg: str): """Called when a DAG run enters the running state.""" - logger.info(f"[CLASS LISTENER] DAG run started: {dag_run.dag_id}/{dag_run.run_id}") - _store_dagrun_event(dag_run, 'running') + logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") + + if _should_integrate_dag_with_aiida(dag_run): + logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") + try: + _create_workchain_node_with_inputs(dag_run) + except Exception as e: + logger.error( + f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True + ) @hookimpl def on_dag_run_success(self, dag_run: DagRun, msg: str): """Called when a DAG run completes successfully.""" - logger.info(f"[CLASS LISTENER] DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") - _store_dagrun_event(dag_run, 'success') + logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") + + if _should_integrate_dag_with_aiida(dag_run): + logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") + try: + _finalize_workchain_node_with_outputs(dag_run) + except Exception as e: + logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) @hookimpl def on_dag_run_failed(self, dag_run: DagRun, msg: str): """Called when a DAG run fails.""" - logger.info(f"[CLASS LISTENER] DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") - _store_dagrun_event(dag_run, 'failed') + logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") + # Optionally store failed runs in AiiDA with appropriate exit status + # Create listener instance -dag_run_listener = DagRunListener() +aiida_taskgroup_listener = AiiDATaskGroupIntegrationListener() + +# Plugin registration class AiidaDagRunListener(AirflowPlugin): name = "aiida_dag_run_listener" - listeners = [dag_run_listener] + listeners = [aiida_taskgroup_listener] diff --git a/src/airflow_provider_aiida/plugins/dagrun_listener_async.py.bak b/src/airflow_provider_aiida/plugins/dagrun_listener_async.py.bak deleted file mode 100644 index fdfb6d4..0000000 --- a/src/airflow_provider_aiida/plugins/dagrun_listener_async.py.bak +++ /dev/null @@ -1,449 +0,0 @@ -import logging -from airflow.models import DagRun, TaskInstance -from airflow.plugins_manager import AirflowPlugin -from airflow.sdk.definitions.param import Param -from airflow.models import Param as ModelsParam -from airflow.listeners import hookimpl -from aiida import load_profile, orm -from aiida.common.links import LinkType -from pathlib import Path -from typing import Any, Dict, Optional -import json - -load_profile() - -logger = logging.getLogger(__name__) - - -def _param_to_python(param) -> Any: - """ - Convert an Airflow Param object to a Python native value. 
- - Args: - param: Airflow Param object or any other value - - Returns: - Python native value (int, float, bool, str, dict, list, etc.) - """ - - # Check if it's a Param object - if not isinstance(param, (Param, ModelsParam)): - return param - - # Get the actual value - actual_value = param.value - - # Get schema type if available - schema = getattr(param, "schema", {}) - param_type = schema.get("type", None) - - # Convert based on schema type - if param_type == "integer": - try: - return int(actual_value) - except (ValueError, TypeError): - logger.warning(f"Could not convert Param value '{actual_value}' to int") - return actual_value - - elif param_type == "number": - try: - return float(actual_value) - except (ValueError, TypeError): - logger.warning(f"Could not convert Param value '{actual_value}' to float") - return actual_value - - elif param_type == "boolean": - if isinstance(actual_value, bool): - return actual_value - # Handle string representations - if isinstance(actual_value, str): - return actual_value.lower() in ("true", "1", "yes", "on") - return bool(actual_value) - - elif param_type == "string": - return str(actual_value) - - elif param_type == "object": - # Should already be a dict - return actual_value if isinstance(actual_value, dict) else {} - - elif param_type == "array": - # Should already be a list - return actual_value if isinstance(actual_value, (list, tuple)) else [] - - else: - # No type specified or unknown type - return as-is - return actual_value - - -def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]: - """ - Convert a Python value to the appropriate AiiDA Data node. - - Returns None if the value type is not supported or should be skipped. - """ - # First check if it's an Airflow Param and convert it - if isinstance(value, (Param, ModelsParam)): - value = _param_to_python(value) - - # Handle basic types (check bool BEFORE int, since bool is subclass of int) - if isinstance(value, bool): - return orm.Bool(value) - elif isinstance(value, int): - return orm.Int(value) - elif isinstance(value, float): - return orm.Float(value) - elif isinstance(value, str): - return orm.Str(value) - - # Handle collections - store as Dict or List nodes - elif isinstance(value, dict): - return orm.Dict(dict=value) - elif isinstance(value, (list, tuple)): - return orm.List(list=list(value)) - - # Handle Path objects - elif isinstance(value, Path): - return orm.Str(str(value)) - - # For complex objects, try JSON serialization - else: - try: - json_str = json.dumps(value) - return orm.Str(json_str) - except (TypeError, ValueError): - logger.warning( - f"Could not convert value of type {type(value)} to AiiDA node" - ) - return None - - -def _store_params_as_aiida_inputs( - node: orm.Node, params: Dict[str, Any], prefix: str = "" -) -> None: - """ - Store parameters as AiiDA data nodes and link them as inputs. 
- - Args: - node: The AiiDA node to link inputs to - params: Dictionary of parameters to store - prefix: Optional prefix for link labels (e.g., 'dag_params', 'conf') - """ - for key, value in params.items(): - # Create link label with optional prefix - link_label = f"{prefix}_{key}" if prefix else key - - # Skip None values - if value is None: - continue - - # Convert to AiiDA data node - aiida_data = _convert_to_aiida_data(value) - - if aiida_data is not None: - try: - # Store the data node first - aiida_data.store() - # Then add the link - node.base.links.add_incoming( - aiida_data, link_type=LinkType.INPUT_CALC, link_label=link_label - ) - except ValueError as e: - # Link already exists or other constraint violation - logger.debug(f"Could not link {link_label}: {e}") - - -def should_create_calcjob_node(task_instance: TaskInstance) -> bool: - """ - Determine if a task instance should be converted to a CalcJobNode. - - This is where you can implement your logic for identifying which tasks - should be stored in AiiDA. Options: - 1. Check operator type - 2. Check task_id pattern - 3. Check for specific XCom keys - 4. Use task tags/metadata - """ - # Option 1: Check operator type - if "CalcJobTaskOperator" in task_instance.operator: - return True - - # Option 2: Check task_id pattern - if "calcjob" in task_instance.task_id.lower(): - return True - - # Option 3: Check for marker in task metadata - task = task_instance.task - if hasattr(task, "params") and task.params.get("aiida_store", False): - return True - - return False - - -def _store_task_inputs(node: orm.CalcJobNode, task_instance: TaskInstance) -> None: - """ - Store all inputs for a task: params and conf. - - Args: - node: The CalcJobNode to link inputs to - task_instance: The Airflow task instance - """ - # import ipdb; ipdb.set_trace() - # Store task params (static parameters defined in DAG) - if hasattr(task_instance.task, "params") and task_instance.task.params: - _store_params_as_aiida_inputs( - node, task_instance.task.params, prefix="task_param" - ) - - # Store DAG run conf (dynamic parameters from trigger) - if task_instance.dag_run and task_instance.dag_run.conf: - _store_params_as_aiida_inputs(node, task_instance.dag_run.conf, prefix="conf") - - -def _store_task_outputs(node: orm.CalcJobNode, task_instance: TaskInstance) -> None: - """ - Store all outputs from a task: XCom values. - - Args: - node: The CalcJobNode to link outputs to - task_instance: The Airflow task instance - """ - # TODO: continue here, how to get outputs, XComs? - import ipdb - - ipdb.set_trace() - try: - # Get all XCom keys for this task - xcom_data = task_instance.xcom_pull(task_ids=task_instance.task_id, key=None) - - if not xcom_data: - return - - for key, value in xcom_data.items(): - # Skip return_value or handle it separately if needed - if key == "return_value": - continue - - # Convert to AiiDA data node - aiida_data = _convert_to_aiida_data(value) - if aiida_data: - aiida_data.store() - # Link as output (CREATE link) - aiida_data.base.links.add_incoming( - node, link_type=LinkType.CREATE, link_label=f"xcom_{key}" - ) - - except Exception as e: - logger.warning( - f"Could not retrieve XCom outputs for task {task_instance.task_id}: {e}" - ) - - -def _create_calcjob_node_from_task( - task_instance: TaskInstance, parent_workchain_node: orm.WorkChainNode -) -> orm.CalcJobNode: - """ - Create an AiiDA CalcJobNode from an Airflow task instance. 
- """ - node = orm.CalcJobNode() - node.label = f"airflow_calcjob_{task_instance.task_id}" - node.description = f"CalcJob from Airflow task {task_instance.task_id}" - - # Store Airflow metadata in extras - node.base.extras.set("airflow_dag_id", task_instance.dag_id) - node.base.extras.set("airflow_run_id", task_instance.run_id) - node.base.extras.set("airflow_task_id", task_instance.task_id) - - # Set process metadata - node.set_process_type(f"airflow.{task_instance.operator}") - node.set_process_state("finished") - node.set_exit_status(0 if task_instance.state == "success" else 1) - - # Link to parent WorkChainNode (before storing) - if parent_workchain_node: - node.base.links.add_incoming( - parent_workchain_node, - link_type=LinkType.CALL_CALC, - link_label=task_instance.task_id, - ) - - # Add inputs BEFORE storing the node - _store_task_inputs(node, task_instance) - - # Now store the node (inputs are locked in) - node.store() - - # Outputs can be added after storing - # import ipdb; ipdb.set_trace() - _store_task_outputs(node, task_instance) - - logger.info(f"Created CalcJobNode {node.pk} for task {task_instance.task_id}") - return node - - -def _should_integrate_dag_with_aiida(dag_run: DagRun) -> bool: - """Check if this DAG should be stored in AiiDA""" - dag_tags = getattr(dag_run.dag, "tags", []) - return "aiida" in dag_tags - - -def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: - """ - Create a WorkChainNode from a running Airflow DAG and store its inputs. - - Returns: - The created and stored WorkChainNode - """ - workchain_node = orm.WorkChainNode() - workchain_node.label = f"airflow_dag_{dag_run.dag_id}" - workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" - - workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) - workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) - - # Store ALL DAG parameters generically - dag_params = getattr(dag_run.dag, "params", {}) - if dag_params: - _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") - - # Store ALL DAG configuration generically - dag_conf = getattr(dag_run, "conf", {}) - if dag_conf: - _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") - - workchain_node.set_process_state("running") - workchain_node.store() - - logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") - return workchain_node - - -def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: - """ - Find the WorkChainNode for a completed DAG run and add outputs (CalcJobNodes). - If the WorkChainNode doesn't exist yet (because on_dag_run_running wasn't called), - create it first. - """ - from aiida.orm import QueryBuilder - - # Try to find the WorkChainNode created in on_dag_run_running - qb = QueryBuilder() - qb.append( - orm.WorkChainNode, - filters={"extras.airflow_run_id": dag_run.run_id}, - tag="workchain", - ) - results = qb.all() - - if not results: - # WorkChainNode doesn't exist yet - create it now with inputs - logger.warning( - f"WorkChainNode not found for run_id {dag_run.run_id}. " - f"Creating it now (on_dag_run_running may not have been called)." 
- ) - workchain_node = _create_workchain_node_with_inputs(dag_run) - else: - workchain_node = results[0][0] - - # Update process state to finished - workchain_node.set_process_state("finished") - workchain_node.set_exit_status(0) - - # Process each task in the DAG to add CalcJobNodes with outputs - task_instances = dag_run.get_task_instances() - for ti in task_instances: - if ti.state == "success" and should_create_calcjob_node(ti): - _create_calcjob_node_from_task(ti, workchain_node) - - logger.info(f"Finalized WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") - - -# def _create_workchain_node_from_dag(dag_run: DagRun) -> None: -# """ -# Create a WorkChainNode from a successful Airflow DAG run. -# -# Now completely generic - stores all params and conf without assumptions. -# """ -# # Create the WorkChainNode for the entire DAG -# workchain_node = orm.WorkChainNode() -# workchain_node.label = f"airflow_dag_{dag_run.dag_id}" -# workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" -# -# workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) -# workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) -# -# # Store ALL DAG parameters generically -# dag_params = getattr(dag_run.dag, "params", {}) -# if dag_params: -# _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") -# -# # Store ALL DAG configuration generically -# dag_conf = getattr(dag_run, "conf", {}) -# if dag_conf: -# _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") -# -# workchain_node.set_process_state("finished") -# workchain_node.set_exit_status(0) -# workchain_node.store() -# -# # Process each task in the DAG -# task_instances = dag_run.get_task_instances() -# -# for ti in task_instances: -# # Generic detection: process any successful task -# # You can add filters here if needed (e.g., by operator type) -# if ti.state == "success": -# # Check if this is a CalcJob-like task -# # You might want to add metadata to your tasks to identify them -# if should_create_calcjob_node(ti): -# _create_calcjob_node_from_task(ti, workchain_node) -# -# logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") -# - - -# Airflow Listener Plugin -class AiiDAIntegrationListener: - """Listener that integrates Airflow DAG runs with AiiDA provenance""" - - @hookimpl - def on_dag_run_running(self, dag_run: DagRun, msg: str): - """Called when a DAG run enters the running state.""" - logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") - - if _should_integrate_dag_with_aiida(dag_run): - logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") - try: - _create_workchain_node_with_inputs(dag_run) - except Exception as e: - logger.error( - f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True - ) - - @hookimpl - def on_dag_run_success(self, dag_run: DagRun, msg: str): - """Called when a DAG run completes successfully.""" - logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") - - if _should_integrate_dag_with_aiida(dag_run): - logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") - try: - _finalize_workchain_node_with_outputs(dag_run) - except Exception as e: - logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) - - @hookimpl - def on_dag_run_failed(self, dag_run: DagRun, msg: str): - """Called when a DAG run fails.""" - logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") - # Optionally store failed runs in AiiDA with appropriate exit status - - -# 
Create listener instance -aiida_listener = AiiDAIntegrationListener() - - -# Plugin registration -class AiiDAIntegrationPlugin(AirflowPlugin): - name = "aiida_integration_plugin" - listeners = [aiida_listener] diff --git a/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py.bak b/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py.bak deleted file mode 100644 index 23e5b5b..0000000 --- a/src/airflow_provider_aiida/plugins/dagrun_listener_taskgroup.py.bak +++ /dev/null @@ -1,560 +0,0 @@ -import logging -from airflow.models import DagRun, TaskInstance -from airflow.plugins_manager import AirflowPlugin -from airflow.sdk.definitions.param import Param -from airflow.models import Param as ModelsParam -from airflow.listeners import hookimpl -from aiida import load_profile, orm -from aiida.common.links import LinkType -from pathlib import Path -from typing import Any, Dict, Optional -import json -import sys -sys.path.append('/home/geiger_j/aiida_projects/aiida-airflow/git-repos/airflow-prototype/dags/') -from calcjob_inheritance import CalcJobTaskGroup - -load_profile() - -logger = logging.getLogger(__name__) - - -def _param_to_python(param) -> Any: - """ - Convert an Airflow Param object to a Python native value. - - Args: - param: Airflow Param object or any other value - - Returns: - Python native value (int, float, bool, str, dict, list, etc.) - """ - - # Check if it's a Param object - if not isinstance(param, (Param, ModelsParam)): - return param - - # Get the actual value - actual_value = param.value - - # Get schema type if available - schema = getattr(param, "schema", {}) - param_type = schema.get("type", None) - - # Convert based on schema type - if param_type == "integer": - try: - return int(actual_value) - except (ValueError, TypeError): - logger.warning(f"Could not convert Param value '{actual_value}' to int") - return actual_value - - elif param_type == "number": - try: - return float(actual_value) - except (ValueError, TypeError): - logger.warning(f"Could not convert Param value '{actual_value}' to float") - return actual_value - - elif param_type == "boolean": - if isinstance(actual_value, bool): - return actual_value - # Handle string representations - if isinstance(actual_value, str): - return actual_value.lower() in ("true", "1", "yes", "on") - return bool(actual_value) - - elif param_type == "string": - return str(actual_value) - - elif param_type == "object": - # Should already be a dict - return actual_value if isinstance(actual_value, dict) else {} - - elif param_type == "array": - # Should already be a list - return actual_value if isinstance(actual_value, (list, tuple)) else [] - - else: - # No type specified or unknown type - return as-is - return actual_value - - -def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]: - """ - Convert a Python value to the appropriate AiiDA Data node. - - Returns None if the value type is not supported or should be skipped. 
- """ - # First check if it's an Airflow Param and convert it - if isinstance(value, (Param, ModelsParam)): - value = _param_to_python(value) - - # Handle basic types (check bool BEFORE int, since bool is subclass of int) - if isinstance(value, bool): - return orm.Bool(value) - elif isinstance(value, int): - return orm.Int(value) - elif isinstance(value, float): - return orm.Float(value) - elif isinstance(value, str): - return orm.Str(value) - - # Handle collections - store as Dict or List nodes - elif isinstance(value, dict): - return orm.Dict(dict=value) - elif isinstance(value, (list, tuple)): - return orm.List(list=list(value)) - - # Handle Path objects - elif isinstance(value, Path): - return orm.Str(str(value)) - - # For complex objects, try JSON serialization - else: - try: - json_str = json.dumps(value) - return orm.Str(json_str) - except (TypeError, ValueError): - logger.warning( - f"Could not convert value of type {type(value)} to AiiDA node" - ) - return None - - -def _store_params_as_aiida_inputs( - node: orm.Node, params: Dict[str, Any], prefix: str = "" -) -> None: - """ - Store parameters as AiiDA data nodes and link them as inputs. - - Args: - node: The AiiDA node to link inputs to - params: Dictionary of parameters to store - prefix: Optional prefix for link labels (e.g., 'dag_params', 'conf') - """ - for key, value in params.items(): - # Create link label with optional prefix - link_label = f"{prefix}_{key}" if prefix else key - - # Skip None values - if value is None: - continue - - # Convert to AiiDA data node - aiida_data = _convert_to_aiida_data(value) - if isinstance(node, orm.WorkflowNode): - link_type = LinkType.INPUT_WORK - elif isinstance(node, orm.CalculationNode): - link_type = LinkType.INPUT_CALC - - if aiida_data is not None: - try: - # Store the data node first - aiida_data.store() - # Then add the link - node.base.links.add_incoming( - aiida_data, link_type=link_type, link_label=link_label - ) - except ValueError as e: - # Link already exists or other constraint violation - logger.debug(f"Could not link {link_label}: {e}") - - -def should_create_calcjob_node_for_taskgroup(task_instance: TaskInstance) -> bool: - """ - Determine if a task instance is part of a CalcJobTaskGroup. - - This checks if the task is the "parse" task of a CalcJobTaskGroup, - which signals completion of the entire group. - - Args: - task_instance: Airflow task instance - - Returns: - bool: True if this is a parse task from a CalcJobTaskGroup - """ - # Check if task_id indicates it's a parse task in a task group - if ".parse" in task_instance.task_id: - # Verify parent group exists and has the expected structure - group_id = task_instance.task_id.rsplit(".parse", 1)[0] - - # Check if this is likely a CalcJobTaskGroup by looking for sibling tasks - dag_run = task_instance.dag_run - if dag_run: - task_instances = dag_run.get_task_instances() - # Look for the prepare task in the same group - for ti in task_instances: - if ti.task_id == f"{group_id}.prepare": - return True - - return False - - -def _get_taskgroup_id_from_parse_task(task_instance: TaskInstance) -> str: - """Extract the task group ID from a parse task's task_id""" - return task_instance.task_id.rsplit(".parse", 1)[0] - - -def _store_taskgroup_inputs( - node: orm.CalcJobNode, task_instance: TaskInstance, dag_run: DagRun -) -> None: - """ - Store all inputs for a CalcJobTaskGroup. - - Inputs come from: - 1. The prepare task's XCom outputs (to_upload_files, submission_script, to_receive_files) - 2. 
The CalcJobTaskGroup instance's parameters (x, y, sleep, etc.) - 3. DAG-level params and conf - - Args: - node: The CalcJobNode to link inputs to - task_instance: The parse task instance (end of the group) - dag_run: The DAG run containing the task - """ - group_id = _get_taskgroup_id_from_parse_task(task_instance) - prepare_task_id = f"{group_id}.prepare" - - # Get the prepare task instance to access its XCom data - prepare_ti = None - for ti in dag_run.get_task_instances(): - if ti.task_id == prepare_task_id: - prepare_ti = ti - break - - if not prepare_ti: - logger.warning(f"Could not find prepare task {prepare_task_id}") - return - - # Store prepare task outputs as inputs to the CalcJob - try: - to_upload_files = task_instance.xcom_pull( - task_ids=prepare_task_id, key="to_upload_files" - ) - if to_upload_files: - aiida_data = _convert_to_aiida_data(to_upload_files) - if aiida_data: - aiida_data.store() - node.base.links.add_incoming( - aiida_data, - link_type=LinkType.INPUT_CALC, - link_label="to_upload_files", - ) - except Exception as e: - logger.debug(f"Could not store to_upload_files: {e}") - - try: - submission_script = task_instance.xcom_pull( - task_ids=prepare_task_id, key="submission_script" - ) - if submission_script: - aiida_data = _convert_to_aiida_data(submission_script) - if aiida_data: - aiida_data.store() - node.base.links.add_incoming( - aiida_data, - link_type=LinkType.INPUT_CALC, - link_label="submission_script", - ) - except Exception as e: - logger.debug(f"Could not store submission_script: {e}") - - try: - to_receive_files = task_instance.xcom_pull( - task_ids=prepare_task_id, key="to_receive_files" - ) - if to_receive_files: - aiida_data = _convert_to_aiida_data(to_receive_files) - if aiida_data: - aiida_data.store() - node.base.links.add_incoming( - aiida_data, - link_type=LinkType.INPUT_CALC, - link_label="to_receive_files", - ) - except Exception as e: - logger.debug(f"Could not store to_receive_files: {e}") - - # Store DAG-level params and conf - if dag_run.conf: - _store_params_as_aiida_inputs(node, dag_run.conf, prefix="conf") - - dag_params = getattr(dag_run.dag, "params", {}) - if dag_params: - _store_params_as_aiida_inputs(node, dag_params, prefix="dag_param") - - -def _store_taskgroup_outputs( - node: orm.CalcJobNode, task_instance: TaskInstance -) -> None: - """ - Store all outputs from a CalcJobTaskGroup. - - Outputs come from the parse task's XCom data (final_result). - - Args: - node: The CalcJobNode to link outputs to - task_instance: The parse task instance - """ - # NOTE: Aren't all the constructs we create taskgroups? 
How to differentiate between DAG and calcjob - try: - # Get the final_result from the parse task - final_result = task_instance.xcom_pull( - task_ids=task_instance.task_id, key="final_result" - ) - - if final_result: - # Handle tuple format (exit_status, results) from AddJobTaskGroup - if isinstance(final_result, tuple) and len(final_result) == 2: - exit_status, results = final_result - - # Store exit status - exit_status_node = orm.Int(exit_status) - exit_status_node.store() - exit_status_node.base.links.add_incoming( - node, link_type=LinkType.CREATE, link_label="exit_status" - ) - - # Store results dict - if results: - for key, value in results.items(): - aiida_data = _convert_to_aiida_data(value) - if aiida_data: - aiida_data.store() - aiida_data.base.links.add_incoming( - node, - link_type=LinkType.CREATE, - link_label=f"result_{key}", - ) - - # Handle dict format from MultiplyJobTaskGroup - elif isinstance(final_result, dict): - for key, value in final_result.items(): - aiida_data = _convert_to_aiida_data(value) - if aiida_data: - aiida_data.store() - aiida_data.base.links.add_incoming( - node, link_type=LinkType.CREATE, link_label=f"result_{key}" - ) - - # Handle other formats - else: - aiida_data = _convert_to_aiida_data(final_result) - if aiida_data: - aiida_data.store() - aiida_data.base.links.add_incoming( - node, link_type=LinkType.CREATE, link_label="final_result" - ) - - except Exception as e: - logger.warning( - f"Could not retrieve outputs for task {task_instance.task_id}: {e}" - ) - - -def _create_calcjob_node_from_taskgroup( - task_instance: TaskInstance, - parent_workchain_node: orm.WorkChainNode, - dag_run: DagRun, -) -> orm.CalcJobNode: - """ - Create an AiiDA CalcJobNode from a CalcJobTaskGroup (represented by its parse task). - - Args: - task_instance: The parse task instance (end of the TaskGroup) - parent_workchain_node: The parent WorkChainNode for the DAG - dag_run: The DAG run - - Returns: - The created and stored CalcJobNode - """ - # import ipdb; ipdb.set_trace() - # NOTE: locals() - # {'dag_run': , - # 'ipdb': , - # 'parent_workchain_node': , - # 'task_instance': } - - # NOTE: Should one pass a `task_instance` here, or shouldn't it be a taskgroup. - # possibly apply this function to every task group, and have special handling only when it is calcjob. 
if not, then store things in the "workchain way" - - # if isinstance(CalcJobTaskGroup): - - group_id = _get_taskgroup_id_from_parse_task(task_instance) - - node = orm.CalcJobNode() - node.label = f"airflow_calcjob_group_{group_id}" - node.description = f"CalcJob from Airflow TaskGroup {group_id}" - - # Store Airflow metadata in extras - node.base.extras.set("airflow_dag_id", task_instance.dag_id) - node.base.extras.set("airflow_run_id", task_instance.run_id) - node.base.extras.set("airflow_task_group_id", group_id) - - # Set process metadata - node.set_process_type(f"airflow.CalcJobTaskGroup") - node.set_process_state("finished") - - # Determine exit status from parse task result - exit_status = 0 - try: - final_result = task_instance.xcom_pull( - task_ids=task_instance.task_id, key="final_result" - ) - if isinstance(final_result, tuple) and len(final_result) == 2: - exit_status = final_result[0] - except Exception: - pass - - node.set_exit_status(exit_status if task_instance.state == "success" else 1) - - # Link to parent WorkChainNode (before storing) - if parent_workchain_node: - node.base.links.add_incoming( - parent_workchain_node, - link_type=LinkType.CALL_CALC, - link_label=group_id, - ) - - # Add inputs BEFORE storing the node - # TODO: Computer is not an input, but an attribute of the calcjob, or, rather, `metadata.computer` - _store_taskgroup_inputs(node, task_instance, dag_run) - - # Now store the node (inputs are locked in) - node.store() - # NOTE: final_result -> {'__classname__': 'builtins.tuple', '__version__': 1, '__data__': [0, {'result.out': 12}]} - # TODO: - - # Outputs can be added after storing - _store_taskgroup_outputs(node, task_instance) - - logger.info(f"Created CalcJobNode {node.pk} for TaskGroup {group_id}") - return node - - -def _should_integrate_dag_with_aiida(dag_run: DagRun) -> bool: - """Check if this DAG should be stored in AiiDA""" - dag_tags = getattr(dag_run.dag, "tags", []) - # Look for tags that indicate this is a CalcJob workflow - return any(tag in dag_tags for tag in ["aiida", "calcjob", "taskgroup"]) - - -def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: - """ - Create a WorkChainNode from a running Airflow DAG and store its inputs. 
- - Returns: - The created and stored WorkChainNode - """ - workchain_node = orm.WorkChainNode() - workchain_node.label = f"airflow_dag_{dag_run.dag_id}" - workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" - - workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) - workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) - - # Store ALL DAG parameters generically - dag_params = getattr(dag_run.dag, "params", {}) - if dag_params: - # TODO: These are being set as inputs of the workflow, even though they are inputs of the calculations, and are just passed through the dag - # {'machine': 'localhost', 'local_workdir': '/home/geiger_j/airflow/storage/local_workdir', 'remote_workdir': '/home/geiger_j/airflow/storage/remote_workdir'} - # import ipdb; ipdb.set_trace() - _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") - - # Store ALL DAG configuration generically - dag_conf = getattr(dag_run, "conf", {}) - if dag_conf: - _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") - - workchain_node.set_process_state("running") - workchain_node.store() - - logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") - return workchain_node - - -def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: - """ - Find the WorkChainNode for a completed DAG run and add outputs (CalcJobNodes from TaskGroups). - If the WorkChainNode doesn't exist yet, create it first. - """ - # NOTE: Why do i need to query for it? it should be directly accessible, no? -> it's bc i create it in the other function? - from aiida.orm import QueryBuilder - - # Try to find the WorkChainNode created in on_dag_run_running - qb = QueryBuilder() - qb.append( - orm.WorkChainNode, - filters={"extras.airflow_run_id": dag_run.run_id}, - tag="workchain", - ) - results = qb.all() - - if not results: - # WorkChainNode doesn't exist yet - create it now with inputs - logger.warning( - f"WorkChainNode not found for run_id {dag_run.run_id}. " - f"Creating it now (on_dag_run_running may not have been called)." 
- ) - workchain_node = _create_workchain_node_with_inputs(dag_run) - else: - workchain_node = results[0][0] - - # Update process state to finished - workchain_node.set_process_state("finished") - workchain_node.set_exit_status(0) - - # Process each task in the DAG to find CalcJobTaskGroup parse tasks - task_instances = dag_run.get_task_instances() - for ti in task_instances: - if ti.state == "success" and should_create_calcjob_node_for_taskgroup(ti): - _create_calcjob_node_from_taskgroup(ti, workchain_node, dag_run) - - logger.info(f"Finalized WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") - - -# Airflow Listener Plugin -class AiiDATaskGroupIntegrationListener: - """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" - - @hookimpl - def on_dag_run_running(self, dag_run: DagRun, msg: str): - """Called when a DAG run enters the running state.""" - logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") - - if _should_integrate_dag_with_aiida(dag_run): - logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") - try: - _create_workchain_node_with_inputs(dag_run) - except Exception as e: - logger.error( - f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True - ) - - @hookimpl - def on_dag_run_success(self, dag_run: DagRun, msg: str): - """Called when a DAG run completes successfully.""" - logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") - - if _should_integrate_dag_with_aiida(dag_run): - logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") - try: - _finalize_workchain_node_with_outputs(dag_run) - except Exception as e: - logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) - - @hookimpl - def on_dag_run_failed(self, dag_run: DagRun, msg: str): - """Called when a DAG run fails.""" - logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") - # Optionally store failed runs in AiiDA with appropriate exit status - - -# Create listener instance -aiida_taskgroup_listener = AiiDATaskGroupIntegrationListener() - - -# Plugin registration -class AiiDATaskGroupIntegrationPlugin(AirflowPlugin): - name = "aiida_taskgroup_integration_plugin" - listeners = [aiida_taskgroup_listener] - From 79262854dabb445a2156782e76452a4d982442b0 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 2 Oct 2025 17:24:41 +0200 Subject: [PATCH 04/11] add run script for arithmetic add --- run_arithmetic_dag.py | 27 +++++++++++++++++++ .../plugins/aiida_dag_run_listener.py | 5 ++-- 2 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 run_arithmetic_dag.py diff --git a/run_arithmetic_dag.py b/run_arithmetic_dag.py new file mode 100644 index 0000000..fbc5102 --- /dev/null +++ b/run_arithmetic_dag.py @@ -0,0 +1,27 @@ +import subprocess +from pathlib import Path +import json +import os + +# Set AIRFLOW__CORE__DAGS_FOLDER to include example_dags +dag_folder = str(Path(__file__).parent / 'src' / 'airflow_provider_aiida' / 'example_dags') +os.environ['AIRFLOW__CORE__DAGS_FOLDER'] = dag_folder + +# Create directories +Path('/tmp/airflow/local_workdir').mkdir(parents=True, exist_ok=True) +Path('/tmp/airflow/remote_workdir').mkdir(parents=True, exist_ok=True) + +# Configuration +conf = { + 'machine': 'localhost', + 'local_workdir': '/tmp/airflow/local_workdir', + 'remote_workdir': '/tmp/airflow/remote_workdir', + 'add_x': 10, + 'add_y': 5, + 'multiply_x': 7, + 'multiply_y': 3, +} + +# Run DAG using CLI +cmd = ['airflow', 'dags', 'test', 'arithmetic_add_multiply', '--conf', json.dumps(conf)] 
+subprocess.run(cmd) diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index f68caaf..0111907 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -13,8 +13,8 @@ import json # Add dags directory to path for CalcJobTaskGroup import -sys.path.append('/home/geiger_j/aiida_projects/aiida-airflow/git-repos/airflow-prototype/dags/') -from calcjob_inheritance import CalcJobTaskGroup +# sys.path.append('/home/geiger_j/aiida_projects/aiida-airflow/git-repos/airflow-prototype/dags/') +# from calcjob_inheritance import CalcJobTaskGroup load_profile() @@ -453,6 +453,7 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: workchain_node.store() logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + breakpoint() return workchain_node From 0dafabc8d732a7c4e438e88c1c9dc3b16dc41c3e Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 2 Oct 2025 17:47:39 +0200 Subject: [PATCH 05/11] wip; run script --- run_arithmetic_dag.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/run_arithmetic_dag.py b/run_arithmetic_dag.py index fbc5102..c1e5e57 100644 --- a/run_arithmetic_dag.py +++ b/run_arithmetic_dag.py @@ -1,27 +1,31 @@ -import subprocess from pathlib import Path -import json import os +from airflow.models import DagBag +from airflow.utils.state import DagRunState +from datetime import datetime # Set AIRFLOW__CORE__DAGS_FOLDER to include example_dags -dag_folder = str(Path(__file__).parent / 'src' / 'airflow_provider_aiida' / 'example_dags') -os.environ['AIRFLOW__CORE__DAGS_FOLDER'] = dag_folder +dag_folder = str( + Path(__file__).parent / "src" / "airflow_provider_aiida" / "example_dags" +) +os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = dag_folder # Create directories -Path('/tmp/airflow/local_workdir').mkdir(parents=True, exist_ok=True) -Path('/tmp/airflow/remote_workdir').mkdir(parents=True, exist_ok=True) +Path("/tmp/airflow/local_workdir").mkdir(parents=True, exist_ok=True) +Path("/tmp/airflow/remote_workdir").mkdir(parents=True, exist_ok=True) # Configuration conf = { - 'machine': 'localhost', - 'local_workdir': '/tmp/airflow/local_workdir', - 'remote_workdir': '/tmp/airflow/remote_workdir', - 'add_x': 10, - 'add_y': 5, - 'multiply_x': 7, - 'multiply_y': 3, + "machine": "localhost", + "local_workdir": "/tmp/airflow/local_workdir", + "remote_workdir": "/tmp/airflow/remote_workdir", + "add_x": 10, + "add_y": 5, + "multiply_x": 7, + "multiply_y": 3, } -# Run DAG using CLI -cmd = ['airflow', 'dags', 'test', 'arithmetic_add_multiply', '--conf', json.dumps(conf)] -subprocess.run(cmd) +# Run DAG using Python API +dagbag = DagBag(dag_folder=dag_folder) +dag = dagbag.get_dag("arithmetic_add_multiply") +dag.test(run_conf=conf) From 0a9507a6f7c4c576e38295f46dfb2ac5bbfd5445 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Fri, 3 Oct 2025 14:06:36 +0200 Subject: [PATCH 06/11] script runs through now --- .../example_dags/arithmetic_add.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/airflow_provider_aiida/example_dags/arithmetic_add.py b/src/airflow_provider_aiida/example_dags/arithmetic_add.py index a7f045b..153d848 100644 --- a/src/airflow_provider_aiida/example_dags/arithmetic_add.py +++ b/src/airflow_provider_aiida/example_dags/arithmetic_add.py @@ -20,11 +20,20 @@ def 
__init__(self, group_id: str, machine: str, local_workdir: str, remote_workd def prepare(self, **context) -> Dict[str, Any]: """Prepare addition job inputs""" + # Resolve template variables from params + from airflow.models import TaskInstance + ti: TaskInstance = context['task_instance'] + params = context['params'] + + x = params['add_x'] + y = params['add_y'] + sleep = 3 # or get from params if needed + to_upload_files = {} submission_script = f""" -sleep {self.sleep} -echo "$(({self.x}+{self.y}))" > result.out +sleep {sleep} +echo "$(({x}+{y}))" > result.out """ to_receive_files = {"result.out": "addition_result.txt"} @@ -84,12 +93,21 @@ def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workd def prepare(self, **context) -> Dict[str, Any]: """Prepare multiplication job inputs""" + # Resolve template variables from params + from airflow.models import TaskInstance + ti: TaskInstance = context['task_instance'] + params = context['params'] + + x = params['multiply_x'] + y = params['multiply_y'] + sleep = 2 # or get from params if needed + to_upload_files = {} submission_script = f""" -sleep {self.sleep} -echo "$(({self.x}*{self.y}))" > multiply_result.out -echo "Operation: {self.x} * {self.y}" > operation.log +sleep {sleep} +echo "$(({x}*{y}))" > multiply_result.out +echo "Operation: {x} * {y}" > operation.log """ to_receive_files = { @@ -238,4 +256,4 @@ def combine_results(): # Direct usage - add_job and multiply_job ARE TaskGroups! combine_task = combine_results() - [add_job, multiply_job] >> combine_task \ No newline at end of file + [add_job, multiply_job] >> combine_task From 7b2071ee152950588d216998699a21c806635e9b Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Fri, 3 Oct 2025 17:19:10 +0200 Subject: [PATCH 07/11] wip --- .../example_dags/arithmetic_add.py | 26 ++- .../plugins/aiida_dag_run_listener.py | 176 ++++++------------ 2 files changed, 77 insertions(+), 125 deletions(-) diff --git a/src/airflow_provider_aiida/example_dags/arithmetic_add.py b/src/airflow_provider_aiida/example_dags/arithmetic_add.py index 153d848..811255e 100644 --- a/src/airflow_provider_aiida/example_dags/arithmetic_add.py +++ b/src/airflow_provider_aiida/example_dags/arithmetic_add.py @@ -11,6 +11,10 @@ class AddJobTaskGroup(CalcJobTaskGroup): """Addition job task group - directly IS a TaskGroup""" + # Define AiiDA input/output port names (like in aiida-core CalcJob.define()) + # AIIDA_INPUT_PORTS = ['x', 'y'] + # AIIDA_OUTPUT_PORTS = ['sum'] + def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workdir: str, x: int, y: int, sleep: int, **kwargs): self.x = x @@ -43,6 +47,9 @@ def prepare(self, **context) -> Dict[str, Any]: context['task_instance'].xcom_push(key='submission_script', value=submission_script) context['task_instance'].xcom_push(key='to_receive_files', value=to_receive_files) + # Push AiiDA inputs for provenance (matches AIIDA_INPUT_PORTS) + context['task_instance'].xcom_push(key='aiida_inputs', value={'x': x, 'y': y}) + return { "to_upload_files": to_upload_files, "submission_script": submission_script, @@ -68,7 +75,7 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]: continue result_content = file_path.read_text().strip() - print(f"Addition result ({self.x} + {self.y}): {result_content}") + print(f"Addition result: {result_content}") results[file_key] = int(result_content) except Exception as e: @@ -78,12 +85,21 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]: # Store both exit 
status and results in XCom final_result = (exit_status, results) context['task_instance'].xcom_push(key='final_result', value=final_result) + + # Push AiiDA outputs for provenance (matches AIIDA_OUTPUT_PORTS) + if 'result.out' in results: + context['task_instance'].xcom_push(key='aiida_outputs', value={'sum': results['result.out']}) + return final_result class MultiplyJobTaskGroup(CalcJobTaskGroup): """Multiplication job task group - directly IS a TaskGroup""" + # Define AiiDA input/output port names (like in aiida-core CalcJob.define()) + # AIIDA_INPUT_PORTS = ['x', 'y'] + # AIIDA_OUTPUT_PORTS = ['result'] + def __init__(self, group_id: str, machine: str, local_workdir: str, remote_workdir: str, x: int, y: int, sleep: int, **kwargs): self.x = x @@ -120,6 +136,9 @@ def prepare(self, **context) -> Dict[str, Any]: context['task_instance'].xcom_push(key='submission_script', value=submission_script) context['task_instance'].xcom_push(key='to_receive_files', value=to_receive_files) + # Push AiiDA inputs for provenance (matches AIIDA_INPUT_PORTS) + context['task_instance'].xcom_push(key='aiida_inputs', value={'x': x, 'y': y}) + return { "to_upload_files": to_upload_files, "submission_script": submission_script, @@ -164,6 +183,11 @@ def parse(self, local_workdir: str, **context) -> tuple[int, Dict[str, Any]]: # Store both exit status and results in XCom final_result = (exit_status, results) context['task_instance'].xcom_push(key='final_result', value=final_result) + + # Push AiiDA outputs for provenance (matches AIIDA_OUTPUT_PORTS) + if 'result' in results: + context['task_instance'].xcom_push(key='aiida_outputs', value={'result': results['result']}) + return final_result diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index 0111907..3422249 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -204,10 +204,8 @@ def _store_taskgroup_inputs( """ Store all inputs for a CalcJobTaskGroup. - Inputs come from: - 1. The prepare task's XCom outputs (to_upload_files, submission_script, to_receive_files) - 2. The CalcJobTaskGroup instance's parameters (x, y, sleep, etc.) - 3. DAG-level params and conf + Inputs should be explicitly stored by the prepare task in XCom with key 'aiida_inputs'. + This allows each TaskGroup to define its own input structure. 
Args: node: The CalcJobNode to link inputs to @@ -217,73 +215,22 @@ def _store_taskgroup_inputs( group_id = _get_taskgroup_id_from_parse_task(task_instance) prepare_task_id = f"{group_id}.prepare" - # Get the prepare task instance to access its XCom data - prepare_ti = None - for ti in dag_run.get_task_instances(): - if ti.task_id == prepare_task_id: - prepare_ti = ti - break - - if not prepare_ti: - logger.warning(f"Could not find prepare task {prepare_task_id}") - return - - # Store prepare task outputs as inputs to the CalcJob - try: - to_upload_files = task_instance.xcom_pull( - task_ids=prepare_task_id, key="to_upload_files" - ) - if to_upload_files: - aiida_data = _convert_to_aiida_data(to_upload_files) - if aiida_data: - aiida_data.store() - node.base.links.add_incoming( - aiida_data, - link_type=LinkType.INPUT_CALC, - link_label="to_upload_files", - ) - except Exception as e: - logger.debug(f"Could not store to_upload_files: {e}") - + # Try to get inputs explicitly defined by the prepare task try: - submission_script = task_instance.xcom_pull( - task_ids=prepare_task_id, key="submission_script" + aiida_inputs = task_instance.xcom_pull( + task_ids=prepare_task_id, key="aiida_inputs" ) - if submission_script: - aiida_data = _convert_to_aiida_data(submission_script) - if aiida_data: - aiida_data.store() - node.base.links.add_incoming( - aiida_data, - link_type=LinkType.INPUT_CALC, - link_label="submission_script", - ) + if aiida_inputs and isinstance(aiida_inputs, dict): + _store_params_as_aiida_inputs(node, aiida_inputs, prefix="") + return except Exception as e: - logger.debug(f"Could not store submission_script: {e}") + logger.debug(f"Could not retrieve aiida_inputs from prepare task: {e}") - try: - to_receive_files = task_instance.xcom_pull( - task_ids=prepare_task_id, key="to_receive_files" - ) - if to_receive_files: - aiida_data = _convert_to_aiida_data(to_receive_files) - if aiida_data: - aiida_data.store() - node.base.links.add_incoming( - aiida_data, - link_type=LinkType.INPUT_CALC, - link_label="to_receive_files", - ) - except Exception as e: - logger.debug(f"Could not store to_receive_files: {e}") - - # Store DAG-level params and conf - if dag_run.conf: - _store_params_as_aiida_inputs(node, dag_run.conf, prefix="conf") - - dag_params = getattr(dag_run.dag, "params", {}) - if dag_params: - _store_params_as_aiida_inputs(node, dag_params, prefix="dag_param") + # If no explicit inputs provided, log a warning + logger.warning( + f"No 'aiida_inputs' found in XCom for {prepare_task_id}. " + f"CalcJobTaskGroup should push a dict with key 'aiida_inputs' containing input data." + ) def _store_taskgroup_outputs( @@ -292,65 +239,37 @@ def _store_taskgroup_outputs( """ Store all outputs from a CalcJobTaskGroup. - Outputs come from the parse task's XCom data (final_result). + Outputs should be explicitly stored by the parse task in XCom with key 'aiida_outputs'. + This allows each TaskGroup to define its own output structure. 
Args: node: The CalcJobNode to link outputs to task_instance: The parse task instance """ try: - # Get the final_result from the parse task - final_result = task_instance.xcom_pull( - task_ids=task_instance.task_id, key="final_result" + # Try to get outputs explicitly defined by the parse task + aiida_outputs = task_instance.xcom_pull( + task_ids=task_instance.task_id, key="aiida_outputs" ) - if final_result: - # Handle tuple format (exit_status, results) from AddJobTaskGroup - if isinstance(final_result, tuple) and len(final_result) == 2: - exit_status, results = final_result - - # Store exit status - exit_status_node = orm.Int(exit_status) - exit_status_node.store() - exit_status_node.base.links.add_incoming( - node, link_type=LinkType.CREATE, link_label="exit_status" - ) - - # Store results dict - if results: - for key, value in results.items(): - aiida_data = _convert_to_aiida_data(value) - if aiida_data: - aiida_data.store() - aiida_data.base.links.add_incoming( - node, - link_type=LinkType.CREATE, - link_label=f"result_{key}", - ) - - # Handle dict format from MultiplyJobTaskGroup - elif isinstance(final_result, dict): - for key, value in final_result.items(): - aiida_data = _convert_to_aiida_data(value) - if aiida_data: - aiida_data.store() - aiida_data.base.links.add_incoming( - node, link_type=LinkType.CREATE, link_label=f"result_{key}" - ) - - # Handle other formats - else: - aiida_data = _convert_to_aiida_data(final_result) + if aiida_outputs and isinstance(aiida_outputs, dict): + for key, value in aiida_outputs.items(): + aiida_data = _convert_to_aiida_data(value) if aiida_data: aiida_data.store() aiida_data.base.links.add_incoming( - node, link_type=LinkType.CREATE, link_label="final_result" + node, link_type=LinkType.CREATE, link_label=key ) + return except Exception as e: - logger.warning( - f"Could not retrieve outputs for task {task_instance.task_id}: {e}" - ) + logger.debug(f"Could not retrieve aiida_outputs from parse task: {e}") + + # If no explicit outputs provided, log a warning + logger.warning( + f"No 'aiida_outputs' found in XCom for {task_instance.task_id}. " + f"CalcJobTaskGroup parse method should push a dict with key 'aiida_outputs' containing output data." 
+ ) def _create_calcjob_node_from_taskgroup( @@ -372,7 +291,7 @@ def _create_calcjob_node_from_taskgroup( group_id = _get_taskgroup_id_from_parse_task(task_instance) node = orm.CalcJobNode() - node.label = f"airflow_calcjob_group_{group_id}" + node.label = group_id node.description = f"CalcJob from Airflow TaskGroup {group_id}" # Store Airflow metadata in extras @@ -380,8 +299,8 @@ def _create_calcjob_node_from_taskgroup( node.base.extras.set("airflow_run_id", task_instance.run_id) node.base.extras.set("airflow_task_group_id", group_id) - # Set process metadata - node.set_process_type(f"airflow.CalcJobTaskGroup") + # Set process type to the group ID + node.set_process_type(group_id) node.set_process_state("finished") # Determine exit status from parse task result @@ -433,27 +352,36 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: The created and stored WorkChainNode """ workchain_node = orm.WorkChainNode() - workchain_node.label = f"airflow_dag_{dag_run.dag_id}" + workchain_node.label = dag_run.dag_id workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) - # Store ALL DAG parameters generically - dag_params = getattr(dag_run.dag, "params", {}) - if dag_params: - _store_params_as_aiida_inputs(workchain_node, dag_params, prefix="dag_param") + # Set process type to the DAG ID + workchain_node.set_process_type(dag_run.dag_id) - # Store ALL DAG configuration generically + # Store DAG parameters with clean names (no prefixes) + # Use conf if available, otherwise use default params dag_conf = getattr(dag_run, "conf", {}) - if dag_conf: - _store_params_as_aiida_inputs(workchain_node, dag_conf, prefix="conf") + dag_params = getattr(dag_run.dag, "params", {}) + + # Prefer conf values (runtime overrides), fall back to default params + params_to_store = {} + for key, param in dag_params.items(): + # Get actual value from conf or use default + if dag_conf and key in dag_conf: + params_to_store[key] = dag_conf[key] + else: + params_to_store[key] = _param_to_python(param) + + # Store with clean names (no prefix) + _store_params_as_aiida_inputs(workchain_node, params_to_store, prefix="") workchain_node.set_process_state("running") workchain_node.store() logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") - breakpoint() return workchain_node From ad39d9d16cfb66a666435aceab96b4932e7b1f70 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Mon, 13 Oct 2025 11:51:35 +0200 Subject: [PATCH 08/11] wip --- run_arithmetic_dag.py | 18 +++-- .../plugins/aiida_dag_run_listener.py | 78 ++++++++++--------- 2 files changed, 55 insertions(+), 41 deletions(-) diff --git a/run_arithmetic_dag.py b/run_arithmetic_dag.py index c1e5e57..c811dc7 100644 --- a/run_arithmetic_dag.py +++ b/run_arithmetic_dag.py @@ -1,8 +1,5 @@ from pathlib import Path import os -from airflow.models import DagBag -from airflow.utils.state import DagRunState -from datetime import datetime # Set AIRFLOW__CORE__DAGS_FOLDER to include example_dags dag_folder = str( @@ -10,6 +7,9 @@ ) os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = dag_folder +# Import AFTER setting the environment variable +from airflow.models import DagBag + # Create directories Path("/tmp/airflow/local_workdir").mkdir(parents=True, exist_ok=True) Path("/tmp/airflow/remote_workdir").mkdir(parents=True, exist_ok=True) @@ -25,7 +25,13 @@ "multiply_y": 3, } -# Run DAG using Python API 
-dagbag = DagBag(dag_folder=dag_folder) +# Run DAG using test mode (bypasses serialization requirement) +dagbag = DagBag(dag_folder=dag_folder, include_examples=False) dag = dagbag.get_dag("arithmetic_add_multiply") -dag.test(run_conf=conf) + +# Use test mode with execution_date to avoid serialization issues +from datetime import datetime +dag.test( + run_conf=conf, + use_executor=False # Run tasks sequentially in the same process +) diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index 3422249..d2e6be3 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -124,7 +124,7 @@ def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]: return None -def _store_params_as_aiida_inputs( +def _store_dag_inputs_in_aiida( node: orm.Node, params: Dict[str, Any], prefix: str = "" ) -> None: """ @@ -221,7 +221,7 @@ def _store_taskgroup_inputs( task_ids=prepare_task_id, key="aiida_inputs" ) if aiida_inputs and isinstance(aiida_inputs, dict): - _store_params_as_aiida_inputs(node, aiida_inputs, prefix="") + _store_dag_inputs_in_aiida(node, aiida_inputs, prefix="") return except Exception as e: logger.debug(f"Could not retrieve aiida_inputs from prepare task: {e}") @@ -290,18 +290,20 @@ def _create_calcjob_node_from_taskgroup( """ group_id = _get_taskgroup_id_from_parse_task(task_instance) - node = orm.CalcJobNode() - node.label = group_id - node.description = f"CalcJob from Airflow TaskGroup {group_id}" + cj_node: orm.CalcJobNode = orm.CalcJobNode() + cj_node.label = group_id + cj_node.description = f"CalcJob from Airflow TaskGroup {group_id}" # Store Airflow metadata in extras - node.base.extras.set("airflow_dag_id", task_instance.dag_id) - node.base.extras.set("airflow_run_id", task_instance.run_id) - node.base.extras.set("airflow_task_group_id", group_id) + cj_node.base.extras.set("airflow_dag_id", task_instance.dag_id) + cj_node.base.extras.set("airflow_run_id", task_instance.run_id) + cj_node.base.extras.set("airflow_task_group_id", group_id) # Set process type to the group ID - node.set_process_type(group_id) - node.set_process_state("finished") + cj_node.set_process_type(group_id) + cj_node.set_process_state("finished") + import ipdb; ipdb.set_trace() + cj_node.set_process_label('AirflowCalcJob') # Determine exit status from parse task result exit_status = 0 @@ -314,27 +316,27 @@ def _create_calcjob_node_from_taskgroup( except Exception: pass - node.set_exit_status(exit_status if task_instance.state == "success" else 1) + cj_node.set_exit_status(exit_status if task_instance.state == "success" else 1) # Link to parent WorkChainNode (before storing) if parent_workchain_node: - node.base.links.add_incoming( + cj_node.base.links.add_incoming( parent_workchain_node, link_type=LinkType.CALL_CALC, link_label=group_id, ) # Add inputs BEFORE storing the node - _store_taskgroup_inputs(node, task_instance, dag_run) + _store_taskgroup_inputs(cj_node, task_instance, dag_run) # Now store the node (inputs are locked in) - node.store() + cj_node.store() # Outputs can be added after storing - _store_taskgroup_outputs(node, task_instance) + _store_taskgroup_outputs(cj_node, task_instance) - logger.info(f"Created CalcJobNode {node.pk} for TaskGroup {group_id}") - return node + logger.info(f"Created CalcJobNode {cj_node.pk} for TaskGroup {group_id}") + return cj_node def _should_integrate_dag_with_aiida(dag_run: DagRun) -> bool: @@ -351,17 
+353,19 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: Returns: The created and stored WorkChainNode """ - workchain_node = orm.WorkChainNode() - workchain_node.label = dag_run.dag_id - workchain_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" + wc_node: orm.WorkChainNode = orm.WorkChainNode() + wc_node.label = dag_run.dag_id + wc_node.description = f"Workflow from Airflow DAG {dag_run.dag_id}" - workchain_node.base.extras.set("airflow_dag_id", dag_run.dag_id) - workchain_node.base.extras.set("airflow_run_id", dag_run.run_id) + wc_node.base.extras.set("airflow_dag_id", dag_run.dag_id) + wc_node.base.extras.set("airflow_run_id", dag_run.run_id) # Set process type to the DAG ID - workchain_node.set_process_type(dag_run.dag_id) + wc_node.set_process_type(dag_run.dag_id) + import ipdb; ipdb.set_trace() + wc_node.set_process_label('AirflowWorkChain') - # Store DAG parameters with clean names (no prefixes) + # Store DAG parameters # Use conf if available, otherwise use default params dag_conf = getattr(dag_run, "conf", {}) dag_params = getattr(dag_run.dag, "params", {}) @@ -376,13 +380,13 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: params_to_store[key] = _param_to_python(param) # Store with clean names (no prefix) - _store_params_as_aiida_inputs(workchain_node, params_to_store, prefix="") + _store_dag_inputs_in_aiida(wc_node, params_to_store, prefix="") - workchain_node.set_process_state("running") - workchain_node.store() + wc_node.set_process_state("running") + wc_node.store() - logger.info(f"Created WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") - return workchain_node + logger.info(f"Created WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}") + return wc_node def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: @@ -407,32 +411,36 @@ def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: f"WorkChainNode not found for run_id {dag_run.run_id}. " f"Creating it now (on_dag_run_running may not have been called)." 
) - workchain_node = _create_workchain_node_with_inputs(dag_run) + wc_node = _create_workchain_node_with_inputs(dag_run) else: - workchain_node = results[0][0] + wc_node = results[0][0] # Update process state to finished - workchain_node.set_process_state("finished") - workchain_node.set_exit_status(0) + wc_node.set_process_state("finished") + wc_node.set_exit_status(0) # Process each task in the DAG to find CalcJobTaskGroup parse tasks task_instances = dag_run.get_task_instances() for ti in task_instances: if ti.state == "success" and should_create_calcjob_node_for_taskgroup(ti): - _create_calcjob_node_from_taskgroup(ti, workchain_node, dag_run) + _create_calcjob_node_from_taskgroup(ti, wc_node, dag_run) - logger.info(f"Finalized WorkChainNode {workchain_node.pk} for DAG {dag_run.dag_id}") + logger.info(f"Finalized WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}") # Airflow Listener Plugin class AiiDATaskGroupIntegrationListener: """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" + # NOTE: Apparently this is never triggered when using `dag.test` + # Hence, the WorkChainNode is only ever created once the DAG has run through fully + # use instead `on_task_instance_[running,success]` triggers @hookimpl def on_dag_run_running(self, dag_run: DagRun, msg: str): """Called when a DAG run enters the running state.""" logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") + import ipdb; ipdb.set_trace() if _should_integrate_dag_with_aiida(dag_run): logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") try: From 31966dc2058134a4e3aeddd05240e53efb88c803 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Mon, 13 Oct 2025 12:49:14 +0200 Subject: [PATCH 09/11] wip; before using only one hookimpl --- run_arithmetic_dag.py | 9 +- .../plugins/aiida_dag_run_listener.py | 209 ++++++++++++++---- 2 files changed, 172 insertions(+), 46 deletions(-) diff --git a/run_arithmetic_dag.py b/run_arithmetic_dag.py index c811dc7..d59d498 100644 --- a/run_arithmetic_dag.py +++ b/run_arithmetic_dag.py @@ -30,8 +30,13 @@ dag = dagbag.get_dag("arithmetic_add_multiply") # Use test mode with execution_date to avoid serialization issues -from datetime import datetime + dag.test( run_conf=conf, - use_executor=False # Run tasks sequentially in the same process + # execution_date=datetime.now(), + use_executor=False, # Run tasks sequentially in the same process ) + +# Trigger DAG using API client (requires scheduler to be running) +# client: Client = Client() +# client.trigger_dag(dag_id="arithmetic_add_multiply", conf=conf) diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index d2e6be3..8a25ee2 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -302,8 +302,10 @@ def _create_calcjob_node_from_taskgroup( # Set process type to the group ID cj_node.set_process_type(group_id) cj_node.set_process_state("finished") - import ipdb; ipdb.set_trace() - cj_node.set_process_label('AirflowCalcJob') + import ipdb + + ipdb.set_trace() + cj_node.set_process_label("AirflowCalcJob") # Determine exit status from parse task result exit_status = 0 @@ -362,10 +364,12 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: # Set process type to the DAG ID wc_node.set_process_type(dag_run.dag_id) - import ipdb; ipdb.set_trace() - wc_node.set_process_label('AirflowWorkChain') + import ipdb + + 
ipdb.set_trace() + wc_node.set_process_label("AirflowWorkChain") - # Store DAG parameters + # Store DAG parameters # Use conf if available, otherwise use default params dag_conf = getattr(dag_run, "conf", {}) dag_params = getattr(dag_run.dag, "params", {}) @@ -429,45 +433,162 @@ def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: # Airflow Listener Plugin -class AiiDATaskGroupIntegrationListener: - """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" - - # NOTE: Apparently this is never triggered when using `dag.test` - # Hence, the WorkChainNode is only ever created once the DAG has run through fully - # use instead `on_task_instance_[running,success]` triggers - @hookimpl - def on_dag_run_running(self, dag_run: DagRun, msg: str): - """Called when a DAG run enters the running state.""" - logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") - - import ipdb; ipdb.set_trace() - if _should_integrate_dag_with_aiida(dag_run): - logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") - try: - _create_workchain_node_with_inputs(dag_run) - except Exception as e: - logger.error( - f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True - ) - - @hookimpl - def on_dag_run_success(self, dag_run: DagRun, msg: str): - """Called when a DAG run completes successfully.""" - logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") - - if _should_integrate_dag_with_aiida(dag_run): - logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") - try: - _finalize_workchain_node_with_outputs(dag_run) - except Exception as e: - logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) - - @hookimpl - def on_dag_run_failed(self, dag_run: DagRun, msg: str): - """Called when a DAG run fails.""" - logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") - # Optionally store failed runs in AiiDA with appropriate exit status - +# class AiiDATaskGroupIntegrationListener: +# """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" +# +# # NOTE: Apparently this is never triggered when using `dag.test` +# # Hence, the WorkChainNode is only ever created once the DAG has run through fully +# # use instead `on_task_instance_[running,success]` triggers +# @hookimpl +# def on_dag_run_running(self, dag_run: DagRun, msg: str): +# """Called when a DAG run enters the running state.""" +# logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") +# +# import ipdb; ipdb.set_trace() +# if _should_integrate_dag_with_aiida(dag_run): +# logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") +# try: +# _create_workchain_node_with_inputs(dag_run) +# except Exception as e: +# logger.error( +# f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True +# ) +# +# @hookimpl +# def on_dag_run_success(self, dag_run: DagRun, msg: str): +# """Called when a DAG run completes successfully.""" +# logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") +# +# if _should_integrate_dag_with_aiida(dag_run): +# logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") +# try: +# _finalize_workchain_node_with_outputs(dag_run) +# except Exception as e: +# logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) +# +# @hookimpl +# def on_dag_run_failed(self, dag_run: DagRun, msg: str): +# """Called when a DAG run fails.""" +# logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") +# # Optionally store failed runs in AiiDA with appropriate exit status + + +# class 
AiiDATaskGroupIntegrationListener: +# """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" +# +# def __init__(self): +# self.workchain_nodes = {} # Cache: run_id -> WorkChainNode +# +# @hookimpl +# def on_dag_run_running(self, dag_run: DagRun, msg: str): +# """Called when a DAG run enters the running state (NOT called in test mode).""" +# logger.info(f"[HOOK] on_dag_run_running: {dag_run.dag_id}/{dag_run.run_id}") +# +# if _should_integrate_dag_with_aiida(dag_run): +# logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") +# try: +# wc_node = _create_workchain_node_with_inputs(dag_run) +# self.workchain_nodes[dag_run.run_id] = wc_node +# except Exception as e: +# logger.error( +# f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True +# ) +# +# @hookimpl +# def on_task_instance_running(self, previous_state, task_instance: TaskInstance): +# """Called when a task instance starts running.""" +# # Create WorkChainNode when the FIRST task starts (since on_dag_run_running doesn't work in test mode) +# dag_run = task_instance.dag_run +# +# if ( +# dag_run.run_id not in self.workchain_nodes +# and _should_integrate_dag_with_aiida(dag_run) +# ): +# logger.info( +# f"[HOOK] Creating WorkChainNode on first task for DAG {dag_run.dag_id}" +# ) +# try: +# wc_node = _create_workchain_node_with_inputs(dag_run) +# self.workchain_nodes[dag_run.run_id] = wc_node +# except Exception as e: +# logger.error( +# f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True +# ) +# +# @hookimpl +# def on_task_instance_success(self, previous_state, task_instance: TaskInstance): +# """Called when a task instance succeeds.""" +# dag_run = task_instance.dag_run +# +# # Check if this is a CalcJob parse task +# if should_create_calcjob_node_for_taskgroup(task_instance): +# logger.info(f"[HOOK] CalcJob completed: {task_instance.task_id}") +# +# # Get or create the parent WorkChainNode +# wc_node = self.workchain_nodes.get(dag_run.run_id) +# if not wc_node and _should_integrate_dag_with_aiida(dag_run): +# # Fallback: create it now if it doesn't exist +# logger.warning(f"WorkChainNode not found in cache, creating now") +# wc_node = _create_workchain_node_with_inputs(dag_run) +# self.workchain_nodes[dag_run.run_id] = wc_node +# +# if wc_node: +# try: +# _create_calcjob_node_from_taskgroup(task_instance, wc_node, dag_run) +# except Exception as e: +# logger.error(f"Failed to create CalcJobNode: {e}", exc_info=True) +# +# @hookimpl +# def on_dag_run_success(self, dag_run: DagRun, msg: str): +# """Called when a DAG run completes successfully.""" +# logger.info(f"[HOOK] on_dag_run_success: {dag_run.dag_id}/{dag_run.run_id}") +# +# if _should_integrate_dag_with_aiida(dag_run): +# # Get the WorkChainNode +# wc_node = self.workchain_nodes.get(dag_run.run_id) +# +# if not wc_node: +# # Fallback for non-test mode where we might have missed it +# from aiida.orm import QueryBuilder +# +# qb = QueryBuilder() +# qb.append( +# orm.WorkChainNode, +# filters={"extras.airflow_run_id": dag_run.run_id}, +# ) +# results = qb.all() +# +# if results: +# wc_node = results[0][0] +# else: +# logger.warning(f"Creating WorkChainNode at end (shouldn't happen)") +# wc_node = _create_workchain_node_with_inputs(dag_run) +# +# # Finalize the WorkChainNode +# try: +# wc_node.set_process_state("finished") +# wc_node.set_exit_status(0) +# logger.info( +# f"Finalized WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}" +# ) +# +# # Clean up cache +# self.workchain_nodes.pop(dag_run.run_id, None) +# except Exception as e: 
+# logger.error(f"Failed to finalize WorkChainNode: {e}", exc_info=True) +# +# @hookimpl +# def on_dag_run_failed(self, dag_run: DagRun, msg: str): +# """Called when a DAG run fails.""" +# logger.info(f"[HOOK] on_dag_run_failed: {dag_run.dag_id}/{dag_run.run_id}") +# +# if _should_integrate_dag_with_aiida(dag_run): +# wc_node = self.workchain_nodes.get(dag_run.run_id) +# if wc_node: +# wc_node.set_process_state("excepted") +# wc_node.set_exit_status(1) +# self.workchain_nodes.pop(dag_run.run_id, None) +# # Create listener instance aiida_taskgroup_listener = AiiDATaskGroupIntegrationListener() From 4aad408fd5d46f6260b83942b27a236c0a82a5ca Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Mon, 13 Oct 2025 13:09:00 +0200 Subject: [PATCH 10/11] wip; logic in `on_dag_run_success` hookimpl --- .../plugins/aiida_dag_run_listener.py | 278 +++++------------- 1 file changed, 72 insertions(+), 206 deletions(-) diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index 8a25ee2..a8dbbcd 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -12,10 +12,6 @@ from aiida.common.links import LinkType import json -# Add dags directory to path for CalcJobTaskGroup import -# sys.path.append('/home/geiger_j/aiida_projects/aiida-airflow/git-repos/airflow-prototype/dags/') -# from calcjob_inheritance import CalcJobTaskGroup - load_profile() logger = logging.getLogger(__name__) @@ -31,7 +27,6 @@ def _param_to_python(param) -> Any: Returns: Python native value (int, float, bool, str, dict, list, etc.) """ - # Check if it's a Param object if not isinstance(param, (Param, ModelsParam)): return param @@ -50,35 +45,25 @@ def _param_to_python(param) -> Any: except (ValueError, TypeError): logger.warning(f"Could not convert Param value '{actual_value}' to int") return actual_value - elif param_type == "number": try: return float(actual_value) except (ValueError, TypeError): logger.warning(f"Could not convert Param value '{actual_value}' to float") return actual_value - elif param_type == "boolean": if isinstance(actual_value, bool): return actual_value - # Handle string representations if isinstance(actual_value, str): return actual_value.lower() in ("true", "1", "yes", "on") return bool(actual_value) - elif param_type == "string": return str(actual_value) - elif param_type == "object": - # Should already be a dict return actual_value if isinstance(actual_value, dict) else {} - elif param_type == "array": - # Should already be a list return actual_value if isinstance(actual_value, (list, tuple)) else [] - else: - # No type specified or unknown type - return as-is return actual_value @@ -101,17 +86,14 @@ def _convert_to_aiida_data(value: Any) -> Optional[orm.Data]: return orm.Float(value) elif isinstance(value, str): return orm.Str(value) - # Handle collections - store as Dict or List nodes elif isinstance(value, dict): return orm.Dict(dict=value) elif isinstance(value, (list, tuple)): return orm.List(list=list(value)) - # Handle Path objects elif isinstance(value, Path): return orm.Str(str(value)) - # For complex objects, try JSON serialization else: try: @@ -181,8 +163,15 @@ def should_create_calcjob_node_for_taskgroup(task_instance: TaskInstance) -> boo # Verify parent group exists and has the expected structure group_id = task_instance.task_id.rsplit(".parse", 1)[0] - # Check if this is likely a CalcJobTaskGroup by looking for sibling tasks - 
dag_run = task_instance.dag_run + # Get task instances from the dag_run + # Note: In the success hook, we have access to the full dag_run + from airflow import settings + session = settings.Session() + dag_run = session.query(DagRun).filter( + DagRun.dag_id == task_instance.dag_id, + DagRun.run_id == task_instance.run_id + ).first() + if dag_run: task_instances = dag_run.get_task_instances() # Look for the prepare task in the same group @@ -302,9 +291,6 @@ def _create_calcjob_node_from_taskgroup( # Set process type to the group ID cj_node.set_process_type(group_id) cj_node.set_process_state("finished") - import ipdb - - ipdb.set_trace() cj_node.set_process_label("AirflowCalcJob") # Determine exit status from parse task result @@ -364,9 +350,6 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: # Set process type to the DAG ID wc_node.set_process_type(dag_run.dag_id) - import ipdb - - ipdb.set_trace() wc_node.set_process_label("AirflowWorkChain") # Store DAG parameters @@ -393,32 +376,14 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: return wc_node -def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: +def _finalize_workchain_node(wc_node: orm.WorkChainNode, dag_run: DagRun) -> None: """ - Find the WorkChainNode for a completed DAG run and add outputs (CalcJobNodes from TaskGroups). - If the WorkChainNode doesn't exist yet, create it first. + Finalize the WorkChainNode and create CalcJobNodes for all completed task groups. + + Args: + wc_node: The WorkChainNode to finalize + dag_run: The completed DAG run """ - from aiida.orm import QueryBuilder - - # Try to find the WorkChainNode created in on_dag_run_running - qb = QueryBuilder() - qb.append( - orm.WorkChainNode, - filters={"extras.airflow_run_id": dag_run.run_id}, - tag="workchain", - ) - results = qb.all() - - if not results: - # WorkChainNode doesn't exist yet - create it now with inputs - logger.warning( - f"WorkChainNode not found for run_id {dag_run.run_id}. " - f"Creating it now (on_dag_run_running may not have been called)." 
- ) - wc_node = _create_workchain_node_with_inputs(dag_run) - else: - wc_node = results[0][0] - # Update process state to finished wc_node.set_process_state("finished") wc_node.set_exit_status(0) @@ -433,162 +398,63 @@ def _finalize_workchain_node_with_outputs(dag_run: DagRun) -> None: # Airflow Listener Plugin -# class AiiDATaskGroupIntegrationListener: -# """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" -# -# # NOTE: Apparently this is never triggered when using `dag.test` -# # Hence, the WorkChainNode is only ever created once the DAG has run through fully -# # use instead `on_task_instance_[running,success]` triggers -# @hookimpl -# def on_dag_run_running(self, dag_run: DagRun, msg: str): -# """Called when a DAG run enters the running state.""" -# logger.info(f"DAG run started: {dag_run.dag_id}/{dag_run.run_id}") -# -# import ipdb; ipdb.set_trace() -# if _should_integrate_dag_with_aiida(dag_run): -# logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") -# try: -# _create_workchain_node_with_inputs(dag_run) -# except Exception as e: -# logger.error( -# f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True -# ) -# -# @hookimpl -# def on_dag_run_success(self, dag_run: DagRun, msg: str): -# """Called when a DAG run completes successfully.""" -# logger.info(f"DAG run succeeded: {dag_run.dag_id}/{dag_run.run_id}") -# -# if _should_integrate_dag_with_aiida(dag_run): -# logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") -# try: -# _finalize_workchain_node_with_outputs(dag_run) -# except Exception as e: -# logger.error(f"Failed to finalize AiiDA provenance: {e}", exc_info=True) -# -# @hookimpl -# def on_dag_run_failed(self, dag_run: DagRun, msg: str): -# """Called when a DAG run fails.""" -# logger.info(f"DAG run failed: {dag_run.dag_id}/{dag_run.run_id}") -# # Optionally store failed runs in AiiDA with appropriate exit status - - -# class AiiDATaskGroupIntegrationListener: -# """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" -# -# def __init__(self): -# self.workchain_nodes = {} # Cache: run_id -> WorkChainNode -# -# @hookimpl -# def on_dag_run_running(self, dag_run: DagRun, msg: str): -# """Called when a DAG run enters the running state (NOT called in test mode).""" -# logger.info(f"[HOOK] on_dag_run_running: {dag_run.dag_id}/{dag_run.run_id}") -# -# if _should_integrate_dag_with_aiida(dag_run): -# logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") -# try: -# wc_node = _create_workchain_node_with_inputs(dag_run) -# self.workchain_nodes[dag_run.run_id] = wc_node -# except Exception as e: -# logger.error( -# f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True -# ) -# -# @hookimpl -# def on_task_instance_running(self, previous_state, task_instance: TaskInstance): -# """Called when a task instance starts running.""" -# # Create WorkChainNode when the FIRST task starts (since on_dag_run_running doesn't work in test mode) -# dag_run = task_instance.dag_run -# -# if ( -# dag_run.run_id not in self.workchain_nodes -# and _should_integrate_dag_with_aiida(dag_run) -# ): -# logger.info( -# f"[HOOK] Creating WorkChainNode on first task for DAG {dag_run.dag_id}" -# ) -# try: -# wc_node = _create_workchain_node_with_inputs(dag_run) -# self.workchain_nodes[dag_run.run_id] = wc_node -# except Exception as e: -# logger.error( -# f"Failed to create AiiDA WorkChainNode: {e}", exc_info=True -# ) -# -# @hookimpl -# def on_task_instance_success(self, previous_state, task_instance: TaskInstance): -# 
"""Called when a task instance succeeds.""" -# dag_run = task_instance.dag_run -# -# # Check if this is a CalcJob parse task -# if should_create_calcjob_node_for_taskgroup(task_instance): -# logger.info(f"[HOOK] CalcJob completed: {task_instance.task_id}") -# -# # Get or create the parent WorkChainNode -# wc_node = self.workchain_nodes.get(dag_run.run_id) -# if not wc_node and _should_integrate_dag_with_aiida(dag_run): -# # Fallback: create it now if it doesn't exist -# logger.warning(f"WorkChainNode not found in cache, creating now") -# wc_node = _create_workchain_node_with_inputs(dag_run) -# self.workchain_nodes[dag_run.run_id] = wc_node -# -# if wc_node: -# try: -# _create_calcjob_node_from_taskgroup(task_instance, wc_node, dag_run) -# except Exception as e: -# logger.error(f"Failed to create CalcJobNode: {e}", exc_info=True) -# -# @hookimpl -# def on_dag_run_success(self, dag_run: DagRun, msg: str): -# """Called when a DAG run completes successfully.""" -# logger.info(f"[HOOK] on_dag_run_success: {dag_run.dag_id}/{dag_run.run_id}") -# -# if _should_integrate_dag_with_aiida(dag_run): -# # Get the WorkChainNode -# wc_node = self.workchain_nodes.get(dag_run.run_id) -# -# if not wc_node: -# # Fallback for non-test mode where we might have missed it -# from aiida.orm import QueryBuilder -# -# qb = QueryBuilder() -# qb.append( -# orm.WorkChainNode, -# filters={"extras.airflow_run_id": dag_run.run_id}, -# ) -# results = qb.all() -# -# if results: -# wc_node = results[0][0] -# else: -# logger.warning(f"Creating WorkChainNode at end (shouldn't happen)") -# wc_node = _create_workchain_node_with_inputs(dag_run) -# -# # Finalize the WorkChainNode -# try: -# wc_node.set_process_state("finished") -# wc_node.set_exit_status(0) -# logger.info( -# f"Finalized WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}" -# ) -# -# # Clean up cache -# self.workchain_nodes.pop(dag_run.run_id, None) -# except Exception as e: -# logger.error(f"Failed to finalize WorkChainNode: {e}", exc_info=True) -# -# @hookimpl -# def on_dag_run_failed(self, dag_run: DagRun, msg: str): -# """Called when a DAG run fails.""" -# logger.info(f"[HOOK] on_dag_run_failed: {dag_run.dag_id}/{dag_run.run_id}") -# -# if _should_integrate_dag_with_aiida(dag_run): -# wc_node = self.workchain_nodes.get(dag_run.run_id) -# if wc_node: -# wc_node.set_process_state("excepted") -# wc_node.set_exit_status(1) -# self.workchain_nodes.pop(dag_run.run_id, None) -# +class AiiDATaskGroupIntegrationListener: + """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" + + @hookimpl + def on_dag_run_success(self, dag_run: DagRun, msg: str): + """ + Called when a DAG run completes successfully. + + Creates the WorkChainNode with inputs, then creates CalcJobNodes for all + completed task groups, and finally finalizes the WorkChainNode. 
+ """ + logger.info(f"[HOOK] on_dag_run_success: {dag_run.dag_id}/{dag_run.run_id}") + + if not _should_integrate_dag_with_aiida(dag_run): + logger.debug(f"DAG {dag_run.dag_id} not tagged for AiiDA integration") + return + + try: + logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") + wc_node = _create_workchain_node_with_inputs(dag_run) + + logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") + _finalize_workchain_node(wc_node, dag_run) + + logger.info( + f"Successfully integrated DAG {dag_run.dag_id} into AiiDA provenance" + ) + except Exception as e: + logger.error( + f"Failed to integrate DAG {dag_run.dag_id} into AiiDA: {e}", + exc_info=True + ) + + @hookimpl + def on_dag_run_failed(self, dag_run: DagRun, msg: str): + """ + Called when a DAG run fails. + + Creates a failed WorkChainNode for provenance tracking. + """ + logger.info(f"[HOOK] on_dag_run_failed: {dag_run.dag_id}/{dag_run.run_id}") + + if not _should_integrate_dag_with_aiida(dag_run): + return + + try: + # Create WorkChainNode for failed run + wc_node = _create_workchain_node_with_inputs(dag_run) + wc_node.set_process_state("excepted") + wc_node.set_exit_status(1) + logger.info(f"Created failed WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}") + except Exception as e: + logger.error( + f"Failed to create WorkChainNode for failed DAG: {e}", + exc_info=True + ) + # Create listener instance aiida_taskgroup_listener = AiiDATaskGroupIntegrationListener() From 7e325d60d7263a3e3f7343ffecac41575834dec1 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Mon, 13 Oct 2025 13:32:51 +0200 Subject: [PATCH 11/11] wip --- run_arithmetic_dag.py | 15 +++---- .../plugins/aiida_dag_run_listener.py | 39 ++++++++++++------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/run_arithmetic_dag.py b/run_arithmetic_dag.py index d59d498..da11ed1 100644 --- a/run_arithmetic_dag.py +++ b/run_arithmetic_dag.py @@ -1,5 +1,6 @@ from pathlib import Path import os +from airflow.api.client.local_client import Client # Set AIRFLOW__CORE__DAGS_FOLDER to include example_dags dag_folder = str( @@ -31,12 +32,12 @@ # Use test mode with execution_date to avoid serialization issues -dag.test( - run_conf=conf, - # execution_date=datetime.now(), - use_executor=False, # Run tasks sequentially in the same process -) +# dag.test( +# run_conf=conf, +# # execution_date=datetime.now(), +# use_executor=False, # Run tasks sequentially in the same process +# ) # Trigger DAG using API client (requires scheduler to be running) -# client: Client = Client() -# client.trigger_dag(dag_id="arithmetic_add_multiply", conf=conf) +client: Client = Client() +client.trigger_dag(dag_id="arithmetic_add_multiply", conf=conf) diff --git a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py index a8dbbcd..93ec6b8 100644 --- a/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py +++ b/src/airflow_provider_aiida/plugins/aiida_dag_run_listener.py @@ -1,5 +1,4 @@ import logging -import sys from pathlib import Path from typing import Any, Dict, Optional @@ -166,12 +165,17 @@ def should_create_calcjob_node_for_taskgroup(task_instance: TaskInstance) -> boo # Get task instances from the dag_run # Note: In the success hook, we have access to the full dag_run from airflow import settings + session = settings.Session() - dag_run = session.query(DagRun).filter( - DagRun.dag_id == task_instance.dag_id, - DagRun.run_id == task_instance.run_id - ).first() - + dag_run = ( + 
session.query(DagRun) + .filter( + DagRun.dag_id == task_instance.dag_id, + DagRun.run_id == task_instance.run_id, + ) + .first() + ) + if dag_run: task_instances = dag_run.get_task_instances() # Look for the prepare task in the same group @@ -379,7 +383,7 @@ def _create_workchain_node_with_inputs(dag_run: DagRun) -> orm.WorkChainNode: def _finalize_workchain_node(wc_node: orm.WorkChainNode, dag_run: DagRun) -> None: """ Finalize the WorkChainNode and create CalcJobNodes for all completed task groups. - + Args: wc_node: The WorkChainNode to finalize dag_run: The completed DAG run @@ -401,11 +405,15 @@ def _finalize_workchain_node(wc_node: orm.WorkChainNode, dag_run: DagRun) -> Non class AiiDATaskGroupIntegrationListener: """Listener that integrates Airflow CalcJobTaskGroups with AiiDA provenance""" + @hookimpl + def on_dag_run_running(self, dag_run: DagRun, msg: str): + breakpoint() + @hookimpl def on_dag_run_success(self, dag_run: DagRun, msg: str): """ Called when a DAG run completes successfully. - + Creates the WorkChainNode with inputs, then creates CalcJobNodes for all completed task groups, and finally finalizes the WorkChainNode. """ @@ -418,24 +426,24 @@ def on_dag_run_success(self, dag_run: DagRun, msg: str): try: logger.info(f"Creating WorkChainNode for DAG {dag_run.dag_id}") wc_node = _create_workchain_node_with_inputs(dag_run) - + logger.info(f"Finalizing WorkChainNode for DAG {dag_run.dag_id}") _finalize_workchain_node(wc_node, dag_run) - + logger.info( f"Successfully integrated DAG {dag_run.dag_id} into AiiDA provenance" ) except Exception as e: logger.error( f"Failed to integrate DAG {dag_run.dag_id} into AiiDA: {e}", - exc_info=True + exc_info=True, ) @hookimpl def on_dag_run_failed(self, dag_run: DagRun, msg: str): """ Called when a DAG run fails. - + Creates a failed WorkChainNode for provenance tracking. """ logger.info(f"[HOOK] on_dag_run_failed: {dag_run.dag_id}/{dag_run.run_id}") @@ -448,11 +456,12 @@ def on_dag_run_failed(self, dag_run: DagRun, msg: str): wc_node = _create_workchain_node_with_inputs(dag_run) wc_node.set_process_state("excepted") wc_node.set_exit_status(1) - logger.info(f"Created failed WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}") + logger.info( + f"Created failed WorkChainNode {wc_node.pk} for DAG {dag_run.dag_id}" + ) except Exception as e: logger.error( - f"Failed to create WorkChainNode for failed DAG: {e}", - exc_info=True + f"Failed to create WorkChainNode for failed DAG: {e}", exc_info=True )
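
The final patch leaves `aiida_taskgroup_listener = AiiDATaskGroupIntegrationListener()` as a module-level instance, and its `@hookimpl` methods only fire when Airflow's listener machinery picks them up (hence the switch from `dag.test()` to triggering the DAG through the API client with a scheduler running). As a minimal sketch of how that instance could be wired in — the plugin class and its `name` are hypothetical and chosen only for illustration; the provider may already register the listener elsewhere, and this is not part of the patches above:

from airflow.plugins_manager import AirflowPlugin

from airflow_provider_aiida.plugins.aiida_dag_run_listener import (
    aiida_taskgroup_listener,
)


class AiiDAProvenancePlugin(AirflowPlugin):
    """Expose the listener instance so Airflow invokes its @hookimpl methods."""

    # Arbitrary plugin identifier chosen for this sketch.
    name = "aiida_provenance_plugin"

    # Recent Airflow versions accept listener objects (or modules) here; once the
    # plugin is loaded, the scheduler calls on_dag_run_success / on_dag_run_failed
    # on the registered object for every DAG run it processes.
    listeners = [aiida_taskgroup_listener]

With the plugin installed, running the scheduler and triggering the DAG via `Client().trigger_dag(...)` (as in the updated `run_arithmetic_dag.py`) should be enough for the success/failure hooks to create and finalize the WorkChainNode.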