diff --git a/observatory-platform/observatory/platform/dags/load_dags.py b/observatory-platform/observatory/platform/dags/load_dags.py new file mode 100644 index 000000000..587569244 --- /dev/null +++ b/observatory-platform/observatory/platform/dags/load_dags.py @@ -0,0 +1,20 @@ +# Copyright 2023 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ: +# https://airflow.apache.org/docs/stable/faq.html + +from observatory.platform.refactor.workflow import load_dags_from_config + +load_dags_from_config() diff --git a/observatory-platform/observatory/platform/dags/load_workflows.py b/observatory-platform/observatory/platform/dags/load_dags_legacy.py similarity index 80% rename from observatory-platform/observatory/platform/dags/load_workflows.py rename to observatory-platform/observatory/platform/dags/load_dags_legacy.py index ea0fde168..d146fa9ef 100644 --- a/observatory-platform/observatory/platform/dags/load_workflows.py +++ b/observatory-platform/observatory/platform/dags/load_dags_legacy.py @@ -22,13 +22,15 @@ from observatory.platform.airflow import fetch_workflows, make_workflow from observatory.platform.observatory_config import Workflow +from observatory.platform.workflows.workflow import Workflow as ObservatoryWorkflow # Load DAGs workflows: List[Workflow] = fetch_workflows() for config in workflows: logging.info(f"Making Workflow: {config.name}, dag_id={config.dag_id}") workflow = make_workflow(config) - dag = workflow.make_dag() - logging.info(f"Adding DAG: dag_id={workflow.dag_id}, dag={dag}") - globals()[workflow.dag_id] = dag + if isinstance(workflow, ObservatoryWorkflow): + dag = workflow.make_dag() + logging.info(f"Adding DAG: dag_id={workflow.dag_id}, dag={dag}") + globals()[workflow.dag_id] = dag diff --git a/observatory-platform/observatory/platform/observatory_environment.py b/observatory-platform/observatory/platform/observatory_environment.py index 2477b4e61..6936020bc 100644 --- a/observatory-platform/observatory/platform/observatory_environment.py +++ b/observatory-platform/observatory/platform/observatory_environment.py @@ -436,10 +436,11 @@ def add_connection(self, conn: Connection): self.session.add(conn) self.session.commit() - def run_task(self, task_id: str) -> TaskInstance: + def run_task(self, task_id: str, map_index: int = -1) -> TaskInstance: """Run an Airflow task. :param task_id: the Airflow task identifier. + :param map_index: the map index if the task is a daynamic task :return: None. 
""" @@ -448,9 +449,29 @@ def run_task(self, task_id: str) -> TaskInstance: dag = self.dag_run.dag run_id = self.dag_run.run_id task = dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=run_id) + ti = TaskInstance(task, run_id=run_id, map_index=map_index) + ti.refresh_from_db() + + # TODO: remove this when this issue fixed / PR merged: https://github.com/apache/airflow/issues/34023#issuecomment-1705761692 + # https://github.com/apache/airflow/pull/36462 + ignore_task_deps = False + if map_index > -1: + ignore_task_deps = True + + ti.run(ignore_task_deps=ignore_task_deps) + + return ti + + def skip_task(self, task_id: str, map_index: int = -1) -> TaskInstance: + + assert self.dag_run is not None, "with create_dag_run must be called before run_task" + + dag = self.dag_run.dag + run_id = self.dag_run.run_id + task = dag.get_task(task_id=task_id) + ti = TaskInstance(task, run_id=run_id, map_index=map_index) ti.refresh_from_db() - ti.run(ignore_ti_state=True) + ti.set_state(State.SKIPPED) return ti @@ -793,12 +814,16 @@ def assert_dag_structure(self, expected: Dict, dag: DAG): expected_keys = expected.keys() actual_keys = dag.task_dict.keys() + diff = set(expected_keys) - set(actual_keys) self.assertEqual(expected_keys, actual_keys) for task_id, downstream_list in expected.items(): + print(task_id) self.assertTrue(dag.has_task(task_id)) task = dag.get_task(task_id) - self.assertEqual(set(downstream_list), task.downstream_task_ids) + expected = set(downstream_list) + actual = task.downstream_task_ids + self.assertEqual(expected, actual) def assert_dag_load(self, dag_id: str, dag_file: str): """Assert that the given DAG loads from a DagBag. @@ -814,7 +839,7 @@ def assert_dag_load(self, dag_id: str, dag_file: str): shutil.copy(dag_file, os.path.join(dag_folder, os.path.basename(dag_file))) - dag_bag = DagBag(dag_folder=dag_folder) + dag_bag = DagBag(dag_folder=dag_folder, include_examples=False) if dag_bag.import_errors != {}: logging.error(f"DagBag errors: {dag_bag.import_errors}") @@ -837,7 +862,7 @@ def assert_dag_load_from_config(self, dag_id: str): :return: None. """ - self.assert_dag_load(dag_id, os.path.join(module_file_path("observatory.platform.dags"), "load_workflows.py")) + self.assert_dag_load(dag_id, os.path.join(module_file_path("observatory.platform.dags"), "load_dags_legacy.py")) def assert_blob_exists(self, bucket_id: str, blob_name: str): """Assert whether a blob exists or not. @@ -962,13 +987,13 @@ def assert_file_integrity(self, file_path: str, expected_hash: str, algorithm: s self.assertEqual(expected_hash, actual_hash) def assert_cleanup(self, workflow_folder: str): - """Assert that the download, extracted and transformed folders were cleaned up. + """Assert that the files in the workflow_folder folder was cleaned up. :param workflow_folder: the path to the DAGs download folder. :return: None. 
""" - self.assertFalse(os.path.exists(workflow_folder)) + self.assertTrue(len(os.listdir(workflow_folder)) == 0) def setup_mock_file_download( self, uri: str, file_path: str, headers: Dict = None, method: str = httpretty.GET diff --git a/observatory-platform/observatory/platform/refactor/__init__.py b/observatory-platform/observatory/platform/refactor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/observatory-platform/observatory/platform/refactor/sensors.py b/observatory-platform/observatory/platform/refactor/sensors.py new file mode 100644 index 000000000..d858b85e6 --- /dev/null +++ b/observatory-platform/observatory/platform/refactor/sensors.py @@ -0,0 +1,105 @@ +# Copyright 2020, 2021 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: Tuan Chien, Keegan Smith, Jamie Diprose + +from __future__ import annotations + +from datetime import timedelta +from functools import partial +from typing import Callable, List, Optional + +import pendulum +from airflow.models import DagRun +from airflow.sensors.external_task import ExternalTaskSensor +from airflow.utils.db import provide_session +from sqlalchemy.orm.scoping import scoped_session + + +class DagCompleteSensor(ExternalTaskSensor): + """ + A sensor that awaits the completion of an external dag by default. Wait functionality can be customised by + providing a different execution_date_fn. + + The sensor checks for completion of a dag with "external_dag_id" on the logical date returned by the + execution_date_fn. + """ + + def __init__( + self, + task_id: str, + external_dag_id: str, + mode: str = "reschedule", + poke_interval: int = 1200, # Check if dag run is ready every 20 minutes + timeout: int = int(timedelta(days=1).total_seconds()), # Sensor will fail after 1 day of waiting + check_existence: bool = True, + execution_date_fn: Optional[Callable] = None, + **kwargs, + ): + """ + :param task_id: the id of the sensor task to create + :param external_dag_id: the id of the external dag to check + :param mode: The mode of the scheduler. Can be reschedule or poke. + :param poke_interval: how often to check if the external dag run is complete + :param timeout: how long to check before the sensor fails + :param check_existence: whether to check that the provided dag_id exists + :param execution_date_fn: a function that returns the logical date(s) of the external DAG runs to query for, + since you need a logical date and a DAG ID to find a particular DAG run to wait for. 
+        """
+
+        if execution_date_fn is None:
+            execution_date_fn = partial(get_logical_dates, external_dag_id)
+
+        super().__init__(
+            task_id=task_id,
+            external_dag_id=external_dag_id,
+            mode=mode,
+            poke_interval=poke_interval,
+            timeout=timeout,
+            check_existence=check_existence,
+            execution_date_fn=execution_date_fn,
+            **kwargs,
+        )
+
+
+@provide_session
+def get_logical_dates(
+    external_dag_id: str, logical_date: pendulum.DateTime, session: scoped_session = None, **context
+) -> List[pendulum.DateTime]:
+    """Get the most recent logical date of the external DAG whose data interval ends on or before the data interval
+    end of the waiting DAG.
+
+    :param external_dag_id: the DAG ID of the external DAG we are waiting for.
+    :param logical_date: the logical date of the waiting DAG.
+    :param session: the SQL Alchemy session.
+    :param context: the Airflow context.
+    :return: the most recent logical date of the external DAG that falls on or before the data interval end of the
+    waiting DAG, returned as a single-item list, or an empty list if there are no matching DAG runs.
+    """
+
+    data_interval_end = context["data_interval_end"]
+    dag_runs = (
+        session.query(DagRun)
+        .filter(
+            DagRun.dag_id == external_dag_id,
+            DagRun.data_interval_end <= data_interval_end,
+        )
+        .all()
+    )
+    dates = [d.logical_date for d in dag_runs]
+    dates.sort(reverse=True)
+
+    # If more than one date is found, keep only the most recent
+    if len(dates) >= 2:
+        dates = [dates[0]]
+
+    return dates
diff --git a/observatory-platform/observatory/platform/refactor/tasks.py b/observatory-platform/observatory/platform/refactor/tasks.py
new file mode 100644
index 000000000..e93800888
--- /dev/null
+++ b/observatory-platform/observatory/platform/refactor/tasks.py
@@ -0,0 +1,77 @@
+# Copyright 2020-2023 Curtin University
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import annotations
+
+import logging
+from typing import List, Optional
+
+import airflow
+from airflow.decorators import task
+from airflow.exceptions import AirflowNotFoundException
+from airflow.hooks.base import BaseHook
+from airflow.models import Variable
+
+
+@task
+def check_dependencies(airflow_vars: Optional[List[str]] = None, airflow_conns: Optional[List[str]] = None, **context):
+    """Checks if the given Airflow Variables and Connections exist.
+
+    :param airflow_vars: the Airflow Variables to check exist.
+    :param airflow_conns: the Airflow Connections to check exist.
+    :return: None.
+    """
+
+    vars_valid = True
+    conns_valid = True
+    if airflow_vars:
+        vars_valid = check_variables(*airflow_vars)
+    if airflow_conns:
+        conns_valid = check_connections(*airflow_conns)
+
+    if not vars_valid or not conns_valid:
+        raise AirflowNotFoundException("Required variables or connections are missing")
+
+
+def check_variables(*variables):
+    """Checks whether all given airflow variables exist.
+ + :param variables: name of airflow variable + :return: True if all variables are valid + """ + is_valid = True + for name in variables: + try: + Variable.get(name) + except AirflowNotFoundException: + logging.error(f"Airflow variable '{name}' not set.") + is_valid = False + return is_valid + + +def check_connections(*connections): + """Checks whether all given airflow connections exist. + + :param connections: name of airflow connection + :return: True if all connections are valid + """ + is_valid = True + for name in connections: + try: + BaseHook.get_connection(name) + except airflow.exceptions.AirflowNotFoundException: + logging.error(f"Airflow connection '{name}' not set.") + is_valid = False + return is_valid diff --git a/observatory-platform/observatory/platform/refactor/workflow.py b/observatory-platform/observatory/platform/refactor/workflow.py new file mode 100644 index 000000000..05327a886 --- /dev/null +++ b/observatory-platform/observatory/platform/refactor/workflow.py @@ -0,0 +1,383 @@ +# Copyright 2019-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: James Diprose, Aniek Roelofs, Tuan Chien + + +from __future__ import annotations + +import json +import logging +import os +import shutil +from dataclasses import dataclass, field +from pydoc import locate +from typing import Any, Dict, List, Optional + +import pendulum +from airflow import AirflowException +from airflow.models import DAG, DagBag, Variable + +from observatory.platform.airflow import delete_old_xcoms +from observatory.platform.config import AirflowVars + + +def get_data_path() -> str: + """Grabs the DATA_PATH airflow vairable + + :raises AirflowException: Raised if the variable does not exist + :return: DATA_PATH variable contents + """ + + # Try to get environment variable from environment variable first + data_path = os.environ.get(AirflowVars.DATA_PATH) + if data_path is not None: + return data_path + + # Try to get from Airflow Variable + data_path = Variable.get(AirflowVars.DATA_PATH) + if data_path is not None: + return data_path + + raise AirflowException("DATA_PATH variable could not be found.") + + +def fetch_workflows() -> List[Workflow]: + """Fetches Workflow instances from the WORKFLOWS Airflow Variable. + + :return: a list of Workflow instances. + """ + + workflows = [] + workflows_str = Variable.get(AirflowVars.WORKFLOWS) + logging.info(f"workflows_str: {workflows_str}") + + if workflows_str is not None and workflows_str.strip() != "": + try: + workflows = json_string_to_workflows(workflows_str) + logging.info(f"workflows: {workflows}") + except json.decoder.JSONDecodeError as e: + e.msg = f"workflows_str: {workflows_str}\n\n{e.msg}" + + return workflows + + +def load_dags_from_config(): + """Loads DAGs from a workflow config file, stored in the WORKFLOWS Airflow Variable. + + :return: None. 
+ """ + + for workflow in fetch_workflows(): + dag_id = workflow.dag_id + logging.info(f"Making Workflow: {workflow.name}, dag_id={dag_id}") + dag = make_dag(workflow) + + if isinstance(dag, DAG): + logging.info(f"Adding DAG: dag_id={dag_id}, dag={dag}") + globals()[dag_id] = dag + + +def make_dag(workflow: Workflow): + """Make a DAG instance from a Workflow config. + :param workflow: the workflow configuration. + :return: the workflow instance. + """ + + cls = locate(workflow.class_name) + if cls is None: + raise ModuleNotFoundError(f"dag_id={workflow.dag_id}: could not locate class_name={workflow.class_name}") + + return cls(dag_id=workflow.dag_id, cloud_workspace=workflow.cloud_workspace, **workflow.kwargs) + + +def make_workflow_folder(dag_id: str, run_id: str, *subdirs: str) -> str: + """Return the path to this dag release's workflow folder. Will also create it if it doesn't exist + + :param dag_id: The ID of the dag. This is used to find/create the workflow folder + :param run_id: The Airflow DAGs run ID. Examples: "scheduled__2023-03-26T00:00:00+00:00" or "manual__2023-03-26T00:00:00+00:00". + :param subdirs: The folder path structure (if any) to create inside the workspace. e.g. 'download' or 'transform' + :return: the path of the workflow folder + """ + + path = os.path.join(get_data_path(), dag_id, run_id, *subdirs) + os.makedirs(path, exist_ok=True) + return path + + +def fetch_dag_bag(path: str, include_examples: bool = False) -> DagBag: + """Load a DAG Bag from a given path. + + :param path: the path to the DAG bag. + :param include_examples: whether to include example DAGs or not. + :return: None. + """ + logging.info(f"Loading DAG bag from path: {path}") + dag_bag = DagBag(path, include_examples=include_examples) + + if dag_bag is None: + raise Exception(f"DagBag could not be loaded from path: {path}") + + if len(dag_bag.import_errors): + # Collate loading errors as single string and raise it as exception + results = [] + for path, exception in dag_bag.import_errors.items(): + results.append(f"DAG import exception: {path}\n{exception}\n\n") + raise Exception("\n".join(results)) + + return dag_bag + + +def cleanup(dag_id: str, logical_date: str, workflow_folder: str = None, retention_days=31) -> None: + """Delete all files, folders and XComs associated from a release. + + :param dag_id: The ID of the DAG to remove XComs + :param logical_date: The execution date of the DAG run + :param workflow_folder: The top-level workflow folder to clean up + :param retention_days: How many days of Xcom messages to retain + """ + if workflow_folder: + try: + shutil.rmtree(workflow_folder) + except FileNotFoundError as e: + logging.warning(f"No such file or directory {workflow_folder}: {e}") + + delete_old_xcoms(dag_id=dag_id, logical_date=logical_date, retention_days=retention_days) + + +class CloudWorkspace: + def __init__( + self, + *, + project_id: str, + download_bucket: str, + transform_bucket: str, + data_location: str, + output_project_id: Optional[str] = None, + ): + """The CloudWorkspace settings used by workflows. + + project_id: the Google Cloud project id. input_project_id is an alias for project_id. + download_bucket: the Google Cloud Storage bucket where downloads will be stored. + transform_bucket: the Google Cloud Storage bucket where transformed data will be stored. + data_location: the data location for storing information, e.g. where BigQuery datasets should be located. 
+ output_project_id: an optional Google Cloud project id when the outputs of a workflow should be stored in a + different project to the inputs. If an output_project_id is not supplied, the project_id will be used. + """ + + self._project_id = project_id + self._download_bucket = download_bucket + self._transform_bucket = transform_bucket + self._data_location = data_location + self._output_project_id = output_project_id + + @property + def project_id(self) -> str: + return self._project_id + + @project_id.setter + def project_id(self, project_id: str): + self._project_id = project_id + + @property + def download_bucket(self) -> str: + return self._download_bucket + + @download_bucket.setter + def download_bucket(self, download_bucket: str): + self._download_bucket = download_bucket + + @property + def transform_bucket(self) -> str: + return self._transform_bucket + + @transform_bucket.setter + def transform_bucket(self, transform_bucket: str): + self._transform_bucket = transform_bucket + + @property + def data_location(self) -> str: + return self._data_location + + @data_location.setter + def data_location(self, data_location: str): + self._data_location = data_location + + @property + def input_project_id(self) -> str: + return self._project_id + + @input_project_id.setter + def input_project_id(self, project_id: str): + self._project_id = project_id + + @property + def output_project_id(self) -> Optional[str]: + if self._output_project_id is None: + return self._project_id + return self._output_project_id + + @output_project_id.setter + def output_project_id(self, output_project_id: Optional[str]): + self._output_project_id = output_project_id + + @staticmethod + def from_dict(dict_: Dict) -> CloudWorkspace: + """Constructs a CloudWorkspace instance from a dictionary. + + :param dict_: the dictionary. + :return: the Workflow instance. + """ + + project_id = dict_.get("project_id") + download_bucket = dict_.get("download_bucket") + transform_bucket = dict_.get("transform_bucket") + data_location = dict_.get("data_location") + output_project_id = dict_.get("output_project_id") + + return CloudWorkspace( + project_id=project_id, + download_bucket=download_bucket, + transform_bucket=transform_bucket, + data_location=data_location, + output_project_id=output_project_id, + ) + + def to_dict(self) -> Dict: + """CloudWorkspace instance to dictionary. + + :return: the dictionary. + """ + + return dict( + project_id=self._project_id, + download_bucket=self._download_bucket, + transform_bucket=self._transform_bucket, + data_location=self._data_location, + output_project_id=self.output_project_id, + ) + + @staticmethod + def parse_cloud_workspaces(list: List) -> List[CloudWorkspace]: + """Parse the cloud workspaces list object into a list of CloudWorkspace instances. + + :param list: the list. + :return: a list of CloudWorkspace instances. + """ + + return [CloudWorkspace.from_dict(dict_) for dict_ in list] + + +@dataclass +class Workflow: + """A Workflow configuration. + + Attributes: + dag_id: the Airflow DAG identifier for the workflow. + name: a user-friendly name for the workflow. + class_name: the fully qualified class name for the workflow class. + cloud_workspace: the Cloud Workspace to use when running the workflow. + kwargs: a dictionary containing optional keyword arguments that are injected into the workflow constructor. 
+ """ + + dag_id: str = None + name: str = None + class_name: str = None + cloud_workspace: CloudWorkspace = None + kwargs: Optional[Dict] = field(default_factory=lambda: dict()) + + def to_dict(self) -> Dict: + """Workflow instance to dictionary. + + :return: the dictionary. + """ + + cloud_workspace = self.cloud_workspace + if self.cloud_workspace is not None: + cloud_workspace = self.cloud_workspace.to_dict() + + return dict( + dag_id=self.dag_id, + name=self.name, + class_name=self.class_name, + cloud_workspace=cloud_workspace, + kwargs=self.kwargs, + ) + + @staticmethod + def from_dict(dict_: Dict) -> Workflow: + """Constructs a Workflow instance from a dictionary. + + :param dict_: the dictionary. + :return: the Workflow instance. + """ + + dag_id = dict_.get("dag_id") + name = dict_.get("name") + class_name = dict_.get("class_name") + + cloud_workspace = dict_.get("cloud_workspace") + if cloud_workspace is not None: + cloud_workspace = CloudWorkspace.from_dict(cloud_workspace) + + kwargs = dict_.get("kwargs", dict()) + + return Workflow(dag_id, name, class_name, cloud_workspace, kwargs) + + @staticmethod + def parse_workflows(list: List) -> List[Workflow]: + """Parse the workflows list object into a list of Workflow instances. + + :param list: the list. + :return: a list of Workflow instances. + """ + + return [Workflow.from_dict(dict_) for dict_ in list] + + +class PendulumDateTimeEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, pendulum.DateTime): + return obj.isoformat() + return super().default(obj) + + +def workflows_to_json_string(workflows: List[Workflow]) -> str: + """Covnert a list of Workflow instances to a JSON string. + + :param workflows: the Workflow instances. + :return: a JSON string. + """ + + data = [workflow.to_dict() for workflow in workflows] + return json.dumps(data, cls=PendulumDateTimeEncoder) + + +def json_string_to_workflows(json_string: str) -> List[Workflow]: + """Convert a JSON string into a list of Workflow instances. + + :param json_string: a JSON string version of a list of Workflow instances. + :return: a list of Workflow instances. 
+ """ + + def parse_datetime(obj): + for key, value in obj.items(): + try: + obj[key] = pendulum.parse(value) + except (ValueError, TypeError): + pass + return obj + + data = json.loads(json_string, object_hook=parse_datetime) + return Workflow.parse_workflows(data) diff --git a/observatory-platform/observatory/platform/workflows/vm_workflow.py b/observatory-platform/observatory/platform/workflows/vm_workflow.py index d82c17016..3bbde2560 100644 --- a/observatory-platform/observatory/platform/workflows/vm_workflow.py +++ b/observatory-platform/observatory/platform/workflows/vm_workflow.py @@ -16,7 +16,7 @@ import logging from datetime import datetime -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import pendulum from airflow.models.dag import DAG @@ -24,6 +24,7 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.utils.state import DagRunState +from airflow.utils.trigger_rule import TriggerRule from croniter import croniter from observatory.platform.airflow import delete_old_xcoms, get_airflow_connection_password, send_slack_msg @@ -220,7 +221,7 @@ def __init__( self.add_task(self.update_terraform_variable) self.add_task(self.run_terraform) self.add_task(self.check_run_status) - self.add_task(self.cleanup, trigger_rule="none_failed") + self.add_task(self.cleanup) def make_release(self, **kwargs) -> None: """Required for Workflow class. diff --git a/observatory-platform/observatory/platform/workflows/workflow.py b/observatory-platform/observatory/platform/workflows/workflow.py index e7e3510aa..d00994e15 100644 --- a/observatory-platform/observatory/platform/workflows/workflow.py +++ b/observatory-platform/observatory/platform/workflows/workflow.py @@ -22,7 +22,7 @@ import shutil from abc import ABC, abstractmethod from functools import partial -from typing import Any, Callable, Dict, List, Union, Optional +from typing import Any, Callable, Dict, List, Optional, Union try: from typing import Protocol @@ -151,7 +151,57 @@ def __init__(self, *, dag_id: str, run_id: str): self.dag_id = dag_id self.run_id = run_id - self.workflow_folder = make_workflow_folder(self.dag_id, run_id) + + @property + def workflow_folder(self): + """Get the path to the workflow folder, namespaced to a DAG run. Can contain multiple release folders. + + :return: path to folder. + """ + + return make_workflow_folder(self.dag_id, self.run_id) + + @property + def release_folder(self): + """Get the path to the release folder, which resides inside the workflow folder. + + :return: path to folder. + """ + + raise NotImplementedError("self.release_folder should be implemented by subclasses") + + @property + def download_folder(self): + """Get the path to the download folder, which contains downloaded files. Resides in a release folder. + + :return: path to folder. + """ + + path = os.path.join(self.release_folder, "download") + os.makedirs(path, exist_ok=True) + return path + + @property + def extract_folder(self): + """Get the path to the extract folder, which contains extracted files. Resides in a release folder. + + :return: path to folder. + """ + + path = os.path.join(self.release_folder, "extract") + os.makedirs(path, exist_ok=True) + return path + + @property + def transform_folder(self): + """Get the path to the transform folder, which contains transformed files. Resides in a release folder. + + :return: path to folder. 
+ """ + + path = os.path.join(self.release_folder, "transform") + os.makedirs(path, exist_ok=True) + return path def __str__(self): return f"Release(dag_id={self.dag_id}, run_id={self.run_id})" @@ -175,10 +225,14 @@ def __init__( super().__init__(dag_id=dag_id, run_id=run_id) self.snapshot_date = snapshot_date - snapshot = f"snapshot_{snapshot_date.format(DATE_TIME_FORMAT)}" - self.download_folder = make_workflow_folder(self.dag_id, run_id, snapshot, "download") - self.extract_folder = make_workflow_folder(self.dag_id, run_id, snapshot, "extract") - self.transform_folder = make_workflow_folder(self.dag_id, run_id, snapshot, "transform") + @property + def release_folder(self): + """Get the path to the release folder, which resides inside the workflow folder. + + :return: path to folder. + """ + + return make_workflow_folder(self.dag_id, self.run_id, f"snapshot_{self.snapshot_date.format(DATE_TIME_FORMAT)}") def __str__(self): return f"SnapshotRelease(dag_id={self.dag_id}, run_id={self.run_id}, snapshot_date={self.snapshot_date})" @@ -202,10 +256,16 @@ def __init__( super().__init__(dag_id=dag_id, run_id=run_id) self.partition_date = partition_date - partition = f"partition_{partition_date.format(DATE_TIME_FORMAT)}" - self.download_folder = make_workflow_folder(self.dag_id, run_id, partition, "download") - self.extract_folder = make_workflow_folder(self.dag_id, run_id, partition, "extract") - self.transform_folder = make_workflow_folder(self.dag_id, run_id, partition, "transform") + @property + def release_folder(self): + """Get the path to the release folder, which resides inside the workflow folder. + + :return: path to folder. + """ + + return make_workflow_folder( + self.dag_id, self.run_id, f"partition_{self.partition_date.format(DATE_TIME_FORMAT)}" + ) def __str__(self): return f"PartitionRelease(dag_id={self.dag_id}, run_id={self.run_id}, partition_date={self.partition_date})" @@ -238,10 +298,23 @@ def __init__( self.sequence_start = sequence_start self.sequence_end = sequence_end - changefile = f"changefile_{start_date.format(DATE_TIME_FORMAT)}_to_{end_date.format(DATE_TIME_FORMAT)}" - self.download_folder = make_workflow_folder(self.dag_id, run_id, changefile, "download") - self.extract_folder = make_workflow_folder(self.dag_id, run_id, changefile, "extract") - self.transform_folder = make_workflow_folder(self.dag_id, run_id, changefile, "transform") + def __eq__(self, other): + if isinstance(other, ChangefileRelease): + return self.__dict__ == other.__dict__ + return False + + @property + def release_folder(self): + """Get the path to the release folder, which resides inside the workflow folder. + + :return: path to folder. 
+ """ + + return make_workflow_folder( + self.dag_id, + self.run_id, + f"changefile_{self.start_date.format(DATE_TIME_FORMAT)}_to_{self.end_date.format(DATE_TIME_FORMAT)}", + ) def __str__(self): return ( diff --git a/tests/observatory/platform/test_observatory_environment.py b/tests/observatory/platform/test_observatory_environment.py index 274568ddf..143eb3ffe 100644 --- a/tests/observatory/platform/test_observatory_environment.py +++ b/tests/observatory/platform/test_observatory_environment.py @@ -617,21 +617,22 @@ def test_assert_cleanup(self): """Test assert_cleanup""" with CliRunner().isolated_filesystem() as temp_dir: - workflow = os.path.join(temp_dir, "workflow") + workflow_folder = os.path.join(temp_dir, "workflow") + test_folder = os.path.join(workflow_folder, "test") # Make download, extract and transform folders - os.makedirs(workflow) + os.makedirs(test_folder) # Check that assertion is raised when folders exist test_case = ObservatoryTestCase() with self.assertRaises(AssertionError): - test_case.assert_cleanup(workflow) + test_case.assert_cleanup(workflow_folder) # Delete folders - os.rmdir(workflow) + os.rmdir(test_folder) # No error when folders deleted - test_case.assert_cleanup(workflow) + test_case.assert_cleanup(workflow_folder) def test_setup_mock_file_download(self): """Test mocking a file download""" diff --git a/tests/observatory/platform/workflows/test_vm_create.py b/tests/observatory/platform/workflows/test_vm_create.py index 0512d2938..1bad0d0b9 100644 --- a/tests/observatory/platform/workflows/test_vm_create.py +++ b/tests/observatory/platform/workflows/test_vm_create.py @@ -163,27 +163,27 @@ def test_workflow_vm_already_on(self, m_tapi, m_list_workspace_vars): with time_machine.travel(dag_run.start_date, tick=True): # check dependencies ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) + self.assertEqual(State.SUCCESS, ti.state) # check vm state ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) + self.assertEqual(State.SUCCESS, ti.state) # update terraform variable ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SKIPPED) + self.assertEqual(State.SKIPPED, ti.state) # run terraform ti = env.run_task(workflow.run_terraform.__name__) - self.assertEqual(ti.state, State.SKIPPED) + self.assertEqual(State.SKIPPED, ti.state) # check run status ti = env.run_task(workflow.check_run_status.__name__) - self.assertEqual(ti.state, State.SKIPPED) + self.assertEqual(State.SKIPPED, ti.state) # cleanup ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) + self.assertEqual(State.SKIPPED, ti.state) @patch("observatory.platform.workflows.vm_workflow.send_slack_msg") @patch("observatory.platform.workflows.vm_workflow.TerraformApi.get_run_details")