diff --git a/tests/dags/test_dag_example.py b/tests/dags/test_dag_example.py
index 43eb535..451d1e6 100644
--- a/tests/dags/test_dag_example.py
+++ b/tests/dags/test_dag_example.py
@@ -49,7 +49,7 @@ def test_dag_runs(setup_airflow_db, dag_id, dag, fileloc):
     print(f"Testing DAG: {dag_id}, located at: {fileloc}")
     assert dag is not None, f"DAG {dag_id} not found!"
     try:
-        test_utils.run_dag(dag)
+        dag.test()
     except Exception as e:
         print(f"Error running DAG {dag_id}: {e}")
         raise e
\ No newline at end of file
diff --git a/tests/dags/utils.py b/tests/dags/utils.py
index 335a86d..74a34e1 100644
--- a/tests/dags/utils.py
+++ b/tests/dags/utils.py
@@ -9,14 +9,13 @@
 from airflow.exceptions import AirflowSkipException
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance, TaskReturnCode
+from airflow.models.taskinstance import TaskInstance
 from airflow.secrets.local_filesystem import LocalFilesystemBackend
 from airflow.utils import timezone
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType
 from sqlalchemy.orm.session import Session
-from airflow.models.operator import Operator
 
 
 log = logging.getLogger(__name__)
@@ -25,6 +24,8 @@ def run_dag(dag: DAG, conn_file_path: str | None = None) -> DagRun:
     return test_dag(dag=dag, conn_file_path=conn_file_path)
 
 
+# DAG.test() was added in Airflow 2.5.0. To support testing on older Airflow versions, we copy its
+# implementation here.
 @provide_session
 def test_dag(
     dag,
@@ -71,12 +72,24 @@ def test_dag(
     tasks = dag.task_dict
     dag.log.debug("starting dagrun")
+    # Instead of starting a scheduler, we run the minimal loop possible to check
+    # for task readiness and dependency management. This is notably faster
+    # than creating a BackfillJob and allows us to surface logs to the user.
     while dr.state == State.RUNNING:
         schedulable_tis, _ = dr.update_state(session=session)
         for ti in schedulable_tis:
             add_logger_if_needed(dag, ti)
             ti.task = tasks[ti.task_id]
             _run_task(ti, session=session)
+
+        # Add handling for DEFERRED tasks
+        deferred_tis = [ti for ti in dr.get_task_instances() if ti.state == State.DEFERRED]
+        if deferred_tis:
+            # Simulate the trigger event by rescheduling and rerunning the task
+            for ti in deferred_tis:
+                ti.set_state(State.SCHEDULED, session=session)
+                ti.task = tasks[ti.task_id]
+                _run_task(ti, session=session)
 
     if conn_file_path or variable_file_path:
         # Remove the local variables we have added to the secrets_backend_list
@@ -94,6 +107,7 @@ def add_logger_if_needed(dag: DAG, ti: TaskInstance):
     in the command line, rather than needing to search for a log file.
     Args:
         ti: The taskinstance that will receive a logger
+
     """
     logging_format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
     handler = logging.StreamHandler(sys.stdout)
@@ -104,6 +118,7 @@ def add_logger_if_needed(dag: DAG, ti: TaskInstance):
         dag.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
         ti.log.addHandler(handler)
 
+
 def _run_task(ti: TaskInstance, session):
     """
     Run a single task instance, and push result to Xcom for downstream tasks. Bypasses a lot of
@@ -119,9 +134,7 @@ def _run_task(ti: TaskInstance, session):
     else:
         log.info("Running task %s", ti.task_id)
     try:
-        task_status = ti._run_raw_task(session=session)
-        if task_status == TaskReturnCode.DEFERRED:
-            ti._run_raw_task(session=session)
+        ti._run_raw_task(session=session)
         session.flush()
         log.info("%s ran successfully!", ti.task_id)
     except AirflowSkipException:
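
Note on the test change: the test now calls `dag.test()` directly, which only exists on Airflow >= 2.5.0, while `tests/dags/utils.py` keeps the backported copy for older versions. Below is a minimal sketch of how the two paths could be dispatched on the installed version; `run_dag_compat` is a hypothetical helper, not part of this diff, and assumes the `packaging` library is available:

```python
# Hypothetical compatibility shim; run_dag_compat is illustrative only.
from packaging.version import Version

import airflow

from tests.dags import utils as test_utils


def run_dag_compat(dag):
    # DAG.test() only exists from Airflow 2.5.0 onwards; older versions
    # fall back to the backported copy in tests/dags/utils.py.
    if Version(airflow.__version__) >= Version("2.5.0"):
        return dag.test()
    return test_utils.run_dag(dag)
```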
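The new DEFERRED branch in `test_dag` approximates a fired trigger by flipping deferred task instances back to SCHEDULED and running them again, without starting a triggerer. The sketch below shows a DAG that would exercise that branch; it assumes `TimeDeltaSensorAsync` (available since Airflow 2.2) and the `schedule` parameter (Airflow >= 2.4), and all names are illustrative:

```python
# Illustrative DAG whose first task defers, exercising the DEFERRED handling
# added above. TimeDeltaSensorAsync defers itself to the triggerer; under the
# test loop, the new branch reschedules it instead of waiting on a trigger.
from datetime import timedelta

from airflow.decorators import dag, task
from airflow.sensors.time_delta import TimeDeltaSensorAsync
from airflow.utils import timezone


@dag(start_date=timezone.datetime(2024, 1, 1), schedule=None, catchup=False)
def deferral_example():
    wait = TimeDeltaSensorAsync(task_id="wait", delta=timedelta(seconds=1))

    @task
    def done():
        print("resumed after deferral")

    wait >> done()


example_dag = deferral_example()
```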