Feature/taskflow api academic observatory (#656)
jdddog authored Apr 22, 2024
1 parent 08b19d9 commit 830079e
Showing 11 changed files with 725 additions and 38 deletions.
20 changes: 20 additions & 0 deletions observatory-platform/observatory/platform/dags/load_dags.py
@@ -0,0 +1,20 @@
# Copyright 2023 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ:
# https://airflow.apache.org/docs/stable/faq.html

from observatory.platform.refactor.workflow import load_dags_from_config

load_dags_from_config()
@@ -22,13 +22,15 @@

from observatory.platform.airflow import fetch_workflows, make_workflow
from observatory.platform.observatory_config import Workflow
from observatory.platform.workflows.workflow import Workflow as ObservatoryWorkflow

# Load DAGs
workflows: List[Workflow] = fetch_workflows()
for config in workflows:
logging.info(f"Making Workflow: {config.name}, dag_id={config.dag_id}")
workflow = make_workflow(config)
dag = workflow.make_dag()

logging.info(f"Adding DAG: dag_id={workflow.dag_id}, dag={dag}")
globals()[workflow.dag_id] = dag
if isinstance(workflow, ObservatoryWorkflow):
dag = workflow.make_dag()
logging.info(f"Adding DAG: dag_id={workflow.dag_id}, dag={dag}")
globals()[workflow.dag_id] = dag
@@ -436,10 +436,11 @@ def add_connection(self, conn: Connection):
self.session.add(conn)
self.session.commit()

def run_task(self, task_id: str) -> TaskInstance:
def run_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
"""Run an Airflow task.
:param task_id: the Airflow task identifier.
:param map_index: the map index if the task is a dynamically mapped task.
:return: the TaskInstance that was run.
"""

@@ -448,9 +449,29 @@ def run_task(self, task_id: str) -> TaskInstance:
dag = self.dag_run.dag
run_id = self.dag_run.run_id
task = dag.get_task(task_id=task_id)
ti = TaskInstance(task, run_id=run_id)
ti = TaskInstance(task, run_id=run_id, map_index=map_index)
ti.refresh_from_db()

# TODO: remove this when this issue fixed / PR merged: https://github.com/apache/airflow/issues/34023#issuecomment-1705761692
# https://github.com/apache/airflow/pull/36462
ignore_task_deps = False
if map_index > -1:
ignore_task_deps = True

ti.run(ignore_task_deps=ignore_task_deps)

return ti

def skip_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
"""Run an Airflow task and then mark its state as skipped.
:param task_id: the Airflow task identifier.
:param map_index: the map index if the task is a dynamically mapped task.
:return: the TaskInstance.
"""

assert self.dag_run is not None, "with create_dag_run must be called before skip_task"

dag = self.dag_run.dag
run_id = self.dag_run.run_id
task = dag.get_task(task_id=task_id)
ti = TaskInstance(task, run_id=run_id, map_index=map_index)
ti.refresh_from_db()
ti.run(ignore_ti_state=True)
ti.set_state(State.SKIPPED)

return ti
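Below is a minimal sketch (not part of the commit) of how the new map_index argument might be exercised in a test. It assumes env is an instance of the environment class diffed above, set up as in its existing tests; the dag, task ids and date are hypothetical placeholders.

import pendulum
from airflow.utils.state import State

with env.create_dag_run(dag, pendulum.datetime(2024, 1, 1)):
    # Run the first instance of a dynamically mapped task
    ti = env.run_task("process_release", map_index=0)
    assert ti.state == State.SUCCESS

    # Run a task and then force its state to SKIPPED
    ti = env.skip_task("cleanup")
    assert ti.state == State.SKIPPED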

@@ -793,12 +814,16 @@ def assert_dag_structure(self, expected: Dict, dag: DAG):

expected_keys = expected.keys()
actual_keys = dag.task_dict.keys()
diff = set(expected_keys) - set(actual_keys)
self.assertEqual(expected_keys, actual_keys)

for task_id, downstream_list in expected.items():
print(task_id)
self.assertTrue(dag.has_task(task_id))
task = dag.get_task(task_id)
self.assertEqual(set(downstream_list), task.downstream_task_ids)
expected = set(downstream_list)
actual = task.downstream_task_ids
self.assertEqual(expected, actual)

def assert_dag_load(self, dag_id: str, dag_file: str):
"""Assert that the given DAG loads from a DagBag.
@@ -814,7 +839,7 @@ def assert_dag_load(self, dag_id: str, dag_file: str):

shutil.copy(dag_file, os.path.join(dag_folder, os.path.basename(dag_file)))

dag_bag = DagBag(dag_folder=dag_folder)
dag_bag = DagBag(dag_folder=dag_folder, include_examples=False)

if dag_bag.import_errors != {}:
logging.error(f"DagBag errors: {dag_bag.import_errors}")
@@ -837,7 +862,7 @@ def assert_dag_load_from_config(self, dag_id: str):
:return: None.
"""

self.assert_dag_load(dag_id, os.path.join(module_file_path("observatory.platform.dags"), "load_workflows.py"))
self.assert_dag_load(dag_id, os.path.join(module_file_path("observatory.platform.dags"), "load_dags_legacy.py"))

def assert_blob_exists(self, bucket_id: str, blob_name: str):
"""Assert whether a blob exists or not.
@@ -962,13 +987,13 @@ def assert_file_integrity(self, file_path: str, expected_hash: str, algorithm: s
self.assertEqual(expected_hash, actual_hash)

def assert_cleanup(self, workflow_folder: str):
"""Assert that the download, extracted and transformed folders were cleaned up.
"""Assert that the files in the workflow_folder folder was cleaned up.
:param workflow_folder: the path to the DAGs download folder.
:return: None.
"""

self.assertFalse(os.path.exists(workflow_folder))
self.assertTrue(len(os.listdir(workflow_folder)) == 0)

def setup_mock_file_download(
self, uri: str, file_path: str, headers: Dict = None, method: str = httpretty.GET
Empty file.
105 changes: 105 additions & 0 deletions observatory-platform/observatory/platform/refactor/sensors.py
@@ -0,0 +1,105 @@
# Copyright 2020, 2021 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Author: Tuan Chien, Keegan Smith, Jamie Diprose

from __future__ import annotations

from datetime import timedelta
from functools import partial
from typing import Callable, List, Optional

import pendulum
from airflow.models import DagRun
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.db import provide_session
from sqlalchemy.orm.scoping import scoped_session


class DagCompleteSensor(ExternalTaskSensor):
"""
A sensor that waits for an external DAG to complete. The wait behaviour can be customised by providing a
different execution_date_fn.
The sensor checks for completion of the DAG with "external_dag_id" on the logical date(s) returned by the
execution_date_fn.
"""

def __init__(
self,
task_id: str,
external_dag_id: str,
mode: str = "reschedule",
poke_interval: int = 1200, # Check if dag run is ready every 20 minutes
timeout: int = int(timedelta(days=1).total_seconds()), # Sensor will fail after 1 day of waiting
check_existence: bool = True,
execution_date_fn: Optional[Callable] = None,
**kwargs,
):
"""
:param task_id: the id of the sensor task to create
:param external_dag_id: the id of the external dag to check
:param mode: the sensor mode; can be reschedule or poke.
:param poke_interval: how often to check if the external dag run is complete
:param timeout: how long to check before the sensor fails
:param check_existence: whether to check that the provided dag_id exists
:param execution_date_fn: a function that returns the logical date(s) of the external DAG runs to query for,
since you need a logical date and a DAG ID to find a particular DAG run to wait for.
"""

if execution_date_fn is None:
execution_date_fn = partial(get_logical_dates, external_dag_id)

super().__init__(
task_id=task_id,
external_dag_id=external_dag_id,
mode=mode,
poke_interval=poke_interval,
timeout=timeout,
check_existence=check_existence,
execution_date_fn=execution_date_fn,
**kwargs,
)


@provide_session
def get_logical_dates(
external_dag_id: str, logical_date: pendulum.DateTime, session: scoped_session = None, **context
) -> List[pendulum.DateTime]:
"""Get the logical dates for a given external dag that fall between and returns its data_interval_start (logical date)
:param external_dag_id: the DAG ID of the external DAG we are waiting for.
:param logical_date: the logical date of the waiting DAG.
:param session: the SQL Alchemy session.
:param context: the Airflow context.
:return: the last logical date of the external DAG that falls before the data interval end of the waiting DAG.
"""

data_interval_end = context["data_interval_end"]
dag_runs = (
session.query(DagRun)
.filter(
DagRun.dag_id == external_dag_id,
DagRun.data_interval_end <= data_interval_end,
)
.all()
)
dates = [d.logical_date for d in dag_runs]
dates.sort(reverse=True)

# If more than one date, keep only the most recent (dates are sorted in descending order)
if len(dates) >= 2:
dates = [dates[0]]

return dates
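A minimal sketch (not part of the commit) of how DagCompleteSensor might be wired into a downstream DAG, assuming the module above is importable; the dag ids, schedule and start date are hypothetical.

import pendulum
from airflow.decorators import dag, task

from observatory.platform.refactor.sensors import DagCompleteSensor


@dag(dag_id="downstream_workflow", schedule="@weekly", start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def downstream_workflow():
    # Wait for the upstream DAG run covering the same data interval to finish
    sensor = DagCompleteSensor(task_id="wait_for_upstream", external_dag_id="upstream_workflow")

    @task
    def process(**context):
        pass

    sensor >> process()


downstream_workflow()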
77 changes: 77 additions & 0 deletions observatory-platform/observatory/platform/refactor/tasks.py
@@ -0,0 +1,77 @@
# Copyright 2020-2023 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import logging
from typing import List, Optional

import airflow
from airflow.decorators import task
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Variable


@task
def check_dependencies(airflow_vars: Optional[List[str]] = None, airflow_conns: Optional[List[str]] = None, **context):
"""Checks if the given Airflow Variables and Connections exist.
:param airflow_vars: the Airflow Variables to check exist.
:param airflow_conns: the Airflow Connections to check exist.
:return: None.
"""

vars_valid = True
conns_valid = True
if airflow_vars:
vars_valid = check_variables(*airflow_vars)
if airflow_conns:
conns_valid = check_connections(*airflow_conns)

if not vars_valid or not conns_valid:
raise AirflowNotFoundException("Required variables or connections are missing")


def check_variables(*variables):
"""Checks whether all given airflow variables exist.
:param variables: the names of the Airflow variables to check.
:return: True if all variables are valid
"""
is_valid = True
for name in variables:
try:
Variable.get(name)
except AirflowNotFoundException:
logging.error(f"Airflow variable '{name}' not set.")
is_valid = False
return is_valid


def check_connections(*connections):
"""Checks whether all given airflow connections exist.
:param connections: the names of the Airflow connections to check.
:return: True if all connections are valid
"""
is_valid = True
for name in connections:
try:
BaseHook.get_connection(name)
except airflow.exceptions.AirflowNotFoundException:
logging.error(f"Airflow connection '{name}' not set.")
is_valid = False
return is_valid
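Finally, a minimal sketch (not part of the commit) of using the check_dependencies task inside a TaskFlow DAG, assuming the module above is importable; the dag id, variable and connection names are hypothetical.

import pendulum
from airflow.decorators import dag

from observatory.platform.refactor.tasks import check_dependencies


@dag(dag_id="example_workflow", schedule="@daily", start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def example_workflow():
    # Fail the run early if required Airflow Variables or Connections are missing
    check_dependencies(airflow_vars=["data_path"], airflow_conns=["my_api_conn"])


example_workflow()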