Feature/taskflow api academic observatory #656

Merged
merged 3 commits on Apr 22, 2024
20 changes: 20 additions & 0 deletions observatory-platform/observatory/platform/dags/load_dags.py
@@ -0,0 +1,20 @@
# Copyright 2023 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ:
# https://airflow.apache.org/docs/stable/faq.html

from observatory.platform.refactor.workflow import load_dags_from_config

load_dags_from_config()
@@ -22,13 +22,15 @@

from observatory.platform.airflow import fetch_workflows, make_workflow
from observatory.platform.observatory_config import Workflow
from observatory.platform.workflows.workflow import Workflow as ObservatoryWorkflow

# Load DAGs
workflows: List[Workflow] = fetch_workflows()
for config in workflows:
logging.info(f"Making Workflow: {config.name}, dag_id={config.dag_id}")
workflow = make_workflow(config)
dag = workflow.make_dag()

logging.info(f"Adding DAG: dag_id={workflow.dag_id}, dag={dag}")
globals()[workflow.dag_id] = dag
if isinstance(workflow, ObservatoryWorkflow):
dag = workflow.make_dag()
logging.info(f"Adding DAG: dag_id={workflow.dag_id}, dag={dag}")
globals()[workflow.dag_id] = dag
Expand Up @@ -436,10 +436,11 @@ def add_connection(self, conn: Connection):
self.session.add(conn)
self.session.commit()

def run_task(self, task_id: str) -> TaskInstance:
def run_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
"""Run an Airflow task.

:param task_id: the Airflow task identifier.
:param map_index: the map index if the task is a dynamically mapped task.
:return: the TaskInstance that was run.
"""

@@ -448,9 +449,29 @@ def run_task(self, task_id: str) -> TaskInstance:
dag = self.dag_run.dag
run_id = self.dag_run.run_id
task = dag.get_task(task_id=task_id)
ti = TaskInstance(task, run_id=run_id)
ti = TaskInstance(task, run_id=run_id, map_index=map_index)
ti.refresh_from_db()

# TODO: remove this once the following issue is fixed / PR is merged: https://github.com/apache/airflow/issues/34023#issuecomment-1705761692
# https://github.com/apache/airflow/pull/36462
ignore_task_deps = False
if map_index > -1:
ignore_task_deps = True

ti.run(ignore_task_deps=ignore_task_deps)

return ti

def skip_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
"""Mark an Airflow task instance as skipped.

:param task_id: the Airflow task identifier.
:param map_index: the map index if the task is a dynamically mapped task.
:return: the TaskInstance that was skipped.
"""

assert self.dag_run is not None, "with create_dag_run must be called before skip_task"

dag = self.dag_run.dag
run_id = self.dag_run.run_id
task = dag.get_task(task_id=task_id)
ti = TaskInstance(task, run_id=run_id, map_index=map_index)
ti.refresh_from_db()
ti.run(ignore_ti_state=True)
ti.set_state(State.SKIPPED)

return ti
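A hypothetical usage sketch of run_task and skip_task from inside a test: the env object stands for the sandbox test environment that defines these methods, its create_dag_run context manager is inferred from the assertion messages above, and the DAG and task IDs are made up for illustration.

import pendulum

logical_date = pendulum.datetime(2024, 1, 1)
with env.create_dag_run(my_dag, logical_date):
    # Run an ordinary task and check that it succeeded
    ti = env.run_task("fetch_releases")
    assert ti.state == "success"

    # Run one instance of a dynamically mapped task via its map index
    env.run_task("process_release", map_index=0)

    # Set a task instance's state to skipped
    env.skip_task("cleanup")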

@@ -793,12 +814,16 @@ def assert_dag_structure(self, expected: Dict, dag: DAG):

expected_keys = expected.keys()
actual_keys = dag.task_dict.keys()
diff = set(expected_keys) - set(actual_keys)
self.assertEqual(expected_keys, actual_keys, msg=f"Tasks expected but missing from the DAG: {diff}")

for task_id, downstream_list in expected.items():
self.assertTrue(dag.has_task(task_id))
task = dag.get_task(task_id)
self.assertEqual(set(downstream_list), task.downstream_task_ids)
expected_downstream = set(downstream_list)
actual_downstream = task.downstream_task_ids
self.assertEqual(expected_downstream, actual_downstream)
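For reference, a minimal sketch of the mapping assert_dag_structure expects, called from a test case that uses this class; the four task IDs form a hypothetical linear DAG and are not taken from this PR. Each key is a task ID and each value lists that task's direct downstream task IDs.

# Hypothetical DAG: check_dependencies >> fetch >> transform >> load
expected = {
    "check_dependencies": ["fetch"],
    "fetch": ["transform"],
    "transform": ["load"],
    "load": [],
}
self.assert_dag_structure(expected, dag)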

def assert_dag_load(self, dag_id: str, dag_file: str):
"""Assert that the given DAG loads from a DagBag.
@@ -814,7 +839,7 @@ def assert_dag_load(self, dag_id: str, dag_file: str):

shutil.copy(dag_file, os.path.join(dag_folder, os.path.basename(dag_file)))

dag_bag = DagBag(dag_folder=dag_folder)
dag_bag = DagBag(dag_folder=dag_folder, include_examples=False)

if dag_bag.import_errors != {}:
logging.error(f"DagBag errors: {dag_bag.import_errors}")
@@ -837,7 +862,7 @@ def assert_dag_load_from_config(self, dag_id: str):
:return: None.
"""

self.assert_dag_load(dag_id, os.path.join(module_file_path("observatory.platform.dags"), "load_workflows.py"))
self.assert_dag_load(dag_id, os.path.join(module_file_path("observatory.platform.dags"), "load_dags_legacy.py"))

def assert_blob_exists(self, bucket_id: str, blob_name: str):
"""Assert whether a blob exists or not.
@@ -962,13 +987,13 @@ def assert_file_integrity(self, file_path: str, expected_hash: str, algorithm: s
self.assertEqual(expected_hash, actual_hash)

def assert_cleanup(self, workflow_folder: str):
"""Assert that the download, extracted and transformed folders were cleaned up.
"""Assert that the files in the workflow_folder folder was cleaned up.

:param workflow_folder: the path to the workflow folder.
:return: None.
"""

self.assertFalse(os.path.exists(workflow_folder))
self.assertTrue(len(os.listdir(workflow_folder)) == 0)

def setup_mock_file_download(
self, uri: str, file_path: str, headers: Dict = None, method: str = httpretty.GET
Empty file.
105 changes: 105 additions & 0 deletions observatory-platform/observatory/platform/refactor/sensors.py
@@ -0,0 +1,105 @@
# Copyright 2020, 2021 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Author: Tuan Chien, Keegan Smith, Jamie Diprose

from __future__ import annotations

from datetime import timedelta
from functools import partial
from typing import Callable, List, Optional

import pendulum
from airflow.models import DagRun
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.db import provide_session
from sqlalchemy.orm.scoping import scoped_session


class DagCompleteSensor(ExternalTaskSensor):
"""
A sensor that waits for an external DAG to complete. The wait behaviour can be customised by
providing a different execution_date_fn.

The sensor checks for completion of the DAG with the given "external_dag_id" on the logical date(s)
returned by execution_date_fn.
"""

def __init__(
self,
task_id: str,
external_dag_id: str,
mode: str = "reschedule",
poke_interval: int = 1200, # Check if dag run is ready every 20 minutes
timeout: int = int(timedelta(days=1).total_seconds()), # Sensor will fail after 1 day of waiting
check_existence: bool = True,
execution_date_fn: Optional[Callable] = None,
**kwargs,
):
"""
:param task_id: the id of the sensor task to create
:param external_dag_id: the id of the external dag to check
:param mode: the sensor mode, either "reschedule" or "poke".
:param poke_interval: how often to check if the external dag run is complete
:param timeout: how long to check before the sensor fails
:param check_existence: whether to check that the provided dag_id exists
:param execution_date_fn: a function that returns the logical date(s) of the external DAG runs to query for,
since you need a logical date and a DAG ID to find a particular DAG run to wait for.
"""

if execution_date_fn is None:
execution_date_fn = partial(get_logical_dates, external_dag_id)

super().__init__(
task_id=task_id,
external_dag_id=external_dag_id,
mode=mode,
poke_interval=poke_interval,
timeout=timeout,
check_existence=check_existence,
execution_date_fn=execution_date_fn,
**kwargs,
)


@provide_session
def get_logical_dates(
external_dag_id: str, logical_date: pendulum.DateTime, session: scoped_session = None, **context
) -> List[pendulum.DateTime]:
"""Get the logical dates for a given external dag that fall between and returns its data_interval_start (logical date)

:param external_dag_id: the DAG ID of the external DAG we are waiting for.
:param logical_date: the logical date of the waiting DAG.
:param session: the SQL Alchemy session.
:param context: the Airflow context.
:return: a list containing the most recent matching logical date of the external DAG (at most one element).
"""

data_interval_end = context["data_interval_end"]
dag_runs = (
session.query(DagRun)
.filter(
DagRun.dag_id == external_dag_id,
DagRun.data_interval_end <= data_interval_end,
)
.all()
)
dates = [d.logical_date for d in dag_runs]
dates.sort(reverse=True)

# If more than one date, keep only the most recent
if len(dates) >= 2:
dates = [dates[0]]

return dates
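A minimal sketch of how DagCompleteSensor might be wired into a downstream DAG; the DAG IDs, schedule and task names below are placeholders, not taken from this PR.

import pendulum
from airflow.decorators import dag, task

from observatory.platform.refactor.sensors import DagCompleteSensor


@dag(schedule="@weekly", start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def downstream_workflow():
    # Wait for the upstream DAG to finish before processing
    wait = DagCompleteSensor(task_id="wait_for_upstream", external_dag_id="upstream_dag")

    @task
    def process():
        print("upstream DAG complete, processing")

    wait >> process()


downstream_workflow()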
77 changes: 77 additions & 0 deletions observatory-platform/observatory/platform/refactor/tasks.py
@@ -0,0 +1,77 @@
# Copyright 2020-2023 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import logging
from typing import List, Optional

import airflow
from airflow.decorators import task
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Variable


@task
def check_dependencies(airflow_vars: Optional[List[str]] = None, airflow_conns: Optional[List[str]] = None, **context):
"""Checks if the given Airflow Variables and Connections exist.

:param airflow_vars: the Airflow Variables to check exist.
:param airflow_conns: the Airflow Connections to check exist.
:return: None.
"""

vars_valid = True
conns_valid = True
if airflow_vars:
vars_valid = check_variables(*airflow_vars)
if airflow_conns:
conns_valid = check_connections(*airflow_conns)

if not vars_valid or not conns_valid:
raise AirflowNotFoundException("Required variables or connections are missing")


def check_variables(*variables):
"""Checks whether all given airflow variables exist.

:param variables: the names of the Airflow Variables to check.
:return: True if all variables exist, False otherwise.
"""
is_valid = True
for name in variables:
try:
Variable.get(name)
except AirflowNotFoundException:
logging.error(f"Airflow variable '{name}' not set.")
is_valid = False
return is_valid


def check_connections(*connections):
"""Checks whether all given airflow connections exist.

:param connections: the names of the Airflow Connections to check.
:return: True if all connections exist, False otherwise.
"""
is_valid = True
for name in connections:
try:
BaseHook.get_connection(name)
except airflow.exceptions.AirflowNotFoundException:
logging.error(f"Airflow connection '{name}' not set.")
is_valid = False
return is_valid
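As a usage sketch, check_dependencies could be placed at the start of a TaskFlow DAG so a run fails fast when required Variables or Connections are missing; the DAG, variable and connection names below are placeholders.

import pendulum
from airflow.decorators import dag, task

from observatory.platform.refactor.tasks import check_dependencies


@dag(schedule="@daily", start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def example_workflow():
    # Fail the run early if required Variables or Connections are missing
    check = check_dependencies(airflow_vars=["data_path"], airflow_conns=["my_api"])

    @task
    def fetch():
        print("dependencies present, fetching data")

    check >> fetch()


example_workflow()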