Skip to content

Commit

Permalink
trying with dag.test()
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 11, 2024
1 parent f379ae9 commit b1fb649
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/dags/test_dag_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_dag_runs(setup_airflow_db, dag_id, dag, fileloc):
print(f"Testing DAG: {dag_id}, located at: {fileloc}")
assert dag is not None, f"DAG {dag_id} not found!"
try:
test_utils.run_dag(dag)
dag.test()
except Exception as e:
print(f"Error running DAG {dag_id}: {e}")
raise e
23 changes: 18 additions & 5 deletions tests/dags/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
from airflow.exceptions import AirflowSkipException
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance, TaskReturnCode
from airflow.models.taskinstance import TaskInstance
from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunType
from sqlalchemy.orm.session import Session
from airflow.models.operator import Operator

log = logging.getLogger(__name__)

Expand All @@ -25,6 +24,8 @@ def run_dag(dag: DAG, conn_file_path: str | None = None) -> DagRun:
return test_dag(dag=dag, conn_file_path=conn_file_path)


# DAG.test() was added in Airflow 2.5.0. To support testing on older Airflow
# versions, we copy its implementation here.
@provide_session
def test_dag(
dag,
Expand Down Expand Up @@ -71,12 +72,24 @@ def test_dag(

tasks = dag.task_dict
dag.log.debug("starting dagrun")
# Instead of starting a scheduler, we run the minimal loop possible to check
# for task readiness and dependency management. This is notably faster
# than creating a BackfillJob and allows us to surface logs to the user
while dr.state == State.RUNNING:
schedulable_tis, _ = dr.update_state(session=session)
for ti in schedulable_tis:
add_logger_if_needed(dag, ti)
ti.task = tasks[ti.task_id]
_run_task(ti, session=session)

        # Handle DEFERRED task instances: with no triggerer running in this
        # minimal loop, deferred tasks would otherwise never be resumed.
deferred_tis = [ti for ti in dr.get_task_instances() if ti.state == State.DEFERRED]
if deferred_tis:
            # NOTE(review): this reschedules the deferred TI without a real
            # trigger event firing, so the task re-runs from the start rather
            # than resuming from its trigger — confirm this is acceptable for
            # test purposes.
for ti in deferred_tis:
ti.set_state(State.SCHEDULED, session=session)
ti.task = tasks[ti.task_id]
_run_task(ti, session=session)

if conn_file_path or variable_file_path:
# Remove the local variables we have added to the secrets_backend_list
Expand All @@ -94,6 +107,7 @@ def add_logger_if_needed(dag: DAG, ti: TaskInstance):
in the command line, rather than needing to search for a log file.
Args:
ti: The taskinstance that will receive a logger
"""
logging_format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
handler = logging.StreamHandler(sys.stdout)
Expand All @@ -104,6 +118,7 @@ def add_logger_if_needed(dag: DAG, ti: TaskInstance):
dag.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
ti.log.addHandler(handler)


def _run_task(ti: TaskInstance, session):
"""
Run a single task instance, and push result to Xcom for downstream tasks. Bypasses a lot of
Expand All @@ -119,9 +134,7 @@ def _run_task(ti: TaskInstance, session):
else:
log.info("Running task %s", ti.task_id)
try:
task_status = ti._run_raw_task(session=session)
if task_status == TaskReturnCode.DEFERRED:
ti._run_raw_task(session=session)
ti._run_raw_task(session=session)
session.flush()
log.info("%s ran successfully!", ti.task_id)
except AirflowSkipException:
Expand Down

0 comments on commit b1fb649

Please sign in to comment.