diff --git a/.astro-registry.yaml b/.astro-registry.yaml index caf5cc0..2334373 100644 --- a/.astro-registry.yaml +++ b/.astro-registry.yaml @@ -12,4 +12,4 @@ operators: triggers: - module: anyscale_provider.triggers.anyscale.AnyscaleJobTrigger - - module: anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger \ No newline at end of file + - module: anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger diff --git a/.coveragerc b/.coveragerc index 04b8342..b24d4a4 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,3 @@ [run] omit = - tests/* \ No newline at end of file + tests/* diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 37bec17..a9b4568 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -20,20 +20,18 @@ jobs: steps: - run: true - Type-Check: + Static-Check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha || github.ref }} - - uses: actions/setup-python@v4 with: - python-version: "3.9" + python-version: "3.11" architecture: "x64" - - run: pip3 install hatch - - run: hatch run tests.py3.9-2.7:type-check + - run: hatch run tests.py3.9-2.7:static-check Run-Unit-Tests: runs-on: ubuntu-latest @@ -133,4 +131,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: coverage-integration-test-${{ matrix.python-version }}-${{ matrix.airflow-version }} - path: .coverage \ No newline at end of file + path: .coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4ed323a..0b2e7de 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,7 @@ repos: name: Run codespell to check for common misspellings in files language: python types: [text] + args: ["--ignore-words", codespell-ignore-words.txt] - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: @@ -88,4 +89,4 @@ ci: autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate skip: - mypy # build of https://github.com/pre-commit/mirrors-mypy:types-PyYAML,types-attrs,attrs,types-requests, - #types-python-dateutil,apache-airflow@v1.5.0 for python@python3 exceeds tier max size 250MiB: 262.6MiB \ No newline at end of file + #types-python-dateutil,apache-airflow@v1.5.0 for python@python3 exceeds tier max size 250MiB: 262.6MiB diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8ef80e1..79c8f5d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -55,4 +55,4 @@ CHANGELOG - .. code-block:: python from anyscale_provider.triggers.anyscale import AnyscaleServiceTrigger - - N/A \ No newline at end of file + - N/A diff --git a/CODEOWNERS b/CODEOWNERS index ea9e2b6..626fcef 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -* @venkatajagannath \ No newline at end of file +* @venkatajagannath diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 40cf8a2..78ae3a8 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -163,4 +163,4 @@ All tests are inside ``./tests`` directory. - Just run ``pytest filepath+filename`` to run the tests. -For more information, please see the contributing guide available `here `_ \ No newline at end of file +For more information, please see the contributing guide available `here `_ diff --git a/README.md b/README.md index 52ac8b0..7f9cb3e 100644 --- a/README.md +++ b/README.md @@ -46,39 +46,39 @@ from pathlib import Path from anyscale_provider.operators.anyscale import SubmitAnyscaleJob default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'start_date': datetime(2024, 4, 2), - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5), + "owner": "airflow", + "depends_on_past": False, + "start_date": datetime(2024, 4, 2), + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), } # Define the Anyscale connection ANYSCALE_CONN_ID = "anyscale_conn" # Constants -FOLDER_PATH = Path(__file__).parent /"ray_scripts" +FOLDER_PATH = Path(__file__).parent / "ray_scripts" dag = DAG( - 'sample_anyscale_workflow', + "sample_anyscale_workflow", default_args=default_args, - description='A DAG to interact with Anyscale triggered manually', + description="A DAG to interact with Anyscale triggered manually", schedule_interval=None, # This DAG is not scheduled, only triggered manually catchup=False, ) submit_anyscale_job = SubmitAnyscaleJob( - task_id='submit_anyscale_job', - conn_id = ANYSCALE_CONN_ID, - name = 'AstroJob', - image_uri = 'anyscale/ray:2.23.0-py311', - compute_config = 'my-compute-config:1', - working_dir = str(FOLDER_PATH), - entrypoint= 'python script.py', - requirements = ["requests","pandas","numpy","torch"], - max_retries = 1, + task_id="submit_anyscale_job", + conn_id=ANYSCALE_CONN_ID, + name="AstroJob", + image_uri="anyscale/ray:2.23.0-py311", + compute_config="my-compute-config:1", + working_dir=str(FOLDER_PATH), + entrypoint="python script.py", + requirements=["requests", "pandas", "numpy", "torch"], + max_retries=1, dag=dag, ) @@ -96,22 +96,22 @@ from anyscale_provider.operators.anyscale import RolloutAnyscaleService from anyscale_provider.hooks.anyscale import AnyscaleHook default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'start_date': datetime(2024, 4, 2), - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5), + "owner": "airflow", + "depends_on_past": False, + "start_date": datetime(2024, 4, 2), + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), } # Define the Anyscale connection ANYSCALE_CONN_ID = "anyscale_conn" dag = DAG( - 'sample_anyscale_service_workflow', + "sample_anyscale_service_workflow", default_args=default_args, - description='A DAG to interact with Anyscale triggered manually', + description="A DAG to interact with Anyscale triggered manually", schedule_interval=None, # This DAG is not scheduled, only triggered manually catchup=False, ) @@ -120,24 +120,25 @@ deploy_anyscale_service = RolloutAnyscaleService( task_id="rollout_anyscale_service", conn_id=ANYSCALE_CONN_ID, name="AstroService", - image_uri='anyscale/ray:2.23.0-py311', - compute_config='my-compute-config:1', + image_uri="anyscale/ray:2.23.0-py311", + compute_config="my-compute-config:1", working_dir="https://github.com/anyscale/docs_examples/archive/refs/heads/main.zip", applications=[{"import_path": "sentiment_analysis.app:model"}], requirements=["transformers", "requests", "pandas", "numpy", "torch"], in_place=False, canary_percent=None, - dag=dag + dag=dag, ) + def terminate_service(): hook = AnyscaleHook(conn_id=ANYSCALE_CONN_ID) - result = hook.terminate_service(service_id="AstroService", - time_delay=5) + result = hook.terminate_service(service_id="AstroService", time_delay=5) print(result) + terminate_anyscale_service = PythonOperator( - task_id='initialize_anyscale_hook', + task_id="initialize_anyscale_hook", python_callable=terminate_service, dag=dag, ) @@ -159,4 +160,4 @@ __________________ All contributions, bug reports, bug fixes, documentation improvements, enhancements are welcome. -A detailed overview an how to contribute can be found in the [Contributing Guide](https://github.com/astronomer/astro-provider-anyscale/blob/main/CONTRIBUTING.rst) \ No newline at end of file +A detailed overview an how to contribute can be found in the [Contributing Guide](https://github.com/astronomer/astro-provider-anyscale/blob/main/CONTRIBUTING.rst) diff --git a/anyscale_provider/__init__.py b/anyscale_provider/__init__.py index c72e5d6..93423b0 100644 --- a/anyscale_provider/__init__.py +++ b/anyscale_provider/__init__.py @@ -1,8 +1,9 @@ __version__ = "1.0.0" -from typing import Any, Dict, Optional +from typing import Any, Dict -def get_provider_info() -> Dict[str,Any]: + +def get_provider_info() -> Dict[str, Any]: return { "package-name": "astro-provider-anyscale", # Required "name": "Anyscale", # Required @@ -11,4 +12,4 @@ def get_provider_info() -> Dict[str,Any]: {"connection-type": "anyscale", "hook-class-name": "anyscale_provider.hooks.anyscale.AnyscaleHook"} ], "versions": [__version__], # Required - } \ No newline at end of file + } diff --git a/anyscale_provider/hooks/anyscale.py b/anyscale_provider/hooks/anyscale.py index f8388ba..baebdc9 100644 --- a/anyscale_provider/hooks/anyscale.py +++ b/anyscale_provider/hooks/anyscale.py @@ -1,17 +1,13 @@ import os import time -import logging from typing import Any, Dict, Optional -import anyscale +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook # Adjusted import based on Airflow's newer version from anyscale import Anyscale from anyscale.job.models import JobConfig, JobStatus -from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceState +from anyscale.service.models import ServiceConfig, ServiceStatus -from airflow.hooks.base import BaseHook # Adjusted import based on Airflow's newer version -from airflow.exceptions import AirflowException -from airflow.compat.functools import cached_property -from anyscale.sdk.anyscale_client.models import * class AnyscaleHook(BaseHook): """ @@ -116,13 +112,12 @@ def __init__(self, conn_id: str = default_conn_name, **kwargs: Any) -> None: # If the token is not found in the connection, try to get it from the environment variable if not token: token = os.getenv("ANYSCALE_CLI_TOKEN") - + if not token: raise AirflowException(f"Missing API token for connection id {self.conn_id}") self.sdk = Anyscale(auth_token=token) - @classmethod def get_ui_field_behaviour(cls) -> Dict[str, Any]: """Return custom field behaviour for the connection form in the UI.""" @@ -133,23 +128,25 @@ def get_ui_field_behaviour(cls) -> Dict[str, Any]: } def submit_job(self, config: JobConfig) -> str: - self.log.info("Creating a job with configuration: {}".format(config)) + self.log.info(f"Creating a job with configuration: {config}") job_id: str = self.sdk.job.submit(config=config) return job_id - def deploy_service(self, config: ServiceConfig, - in_place: bool = False, - canary_percent: Optional[float] = None, - max_surge_percent: Optional[float] = None) -> str: - self.log.info("Deploying a service with configuration: {}".format(config)) - service_id: str = self.sdk.service.deploy(config=config, - in_place=in_place, - canary_percent=canary_percent, - max_surge_percent=max_surge_percent) + def deploy_service( + self, + config: ServiceConfig, + in_place: bool = False, + canary_percent: Optional[float] = None, + max_surge_percent: Optional[float] = None, + ) -> str: + self.log.info(f"Deploying a service with configuration: {config}") + service_id: str = self.sdk.service.deploy( + config=config, in_place=in_place, canary_percent=canary_percent, max_surge_percent=max_surge_percent + ) return service_id def get_job_status(self, job_id: str) -> JobStatus: - self.log.info("Fetching job status for Job name: {}".format(job_id)) + self.log.info(f"Fetching job status for Job name: {job_id}") return self.sdk.job.status(job_id=job_id) def get_service_status(self, service_name: str) -> ServiceStatus: @@ -177,4 +174,4 @@ def terminate_service(self, service_id: str, time_delay: int) -> bool: def get_logs(self, job_id: str) -> str: logs: str = self.sdk.job.get_logs(job_id=job_id) - return logs \ No newline at end of file + return logs diff --git a/anyscale_provider/operators/anyscale.py b/anyscale_provider/operators/anyscale.py index 8c415c4..97330e0 100644 --- a/anyscale_provider/operators/anyscale.py +++ b/anyscale_provider/operators/anyscale.py @@ -1,28 +1,23 @@ # Standard library imports -import logging -import os import time -from typing import List, Dict, Union, Any, Optional +from typing import Any, Dict, List, Optional, Union # Third-party imports import anyscale -from anyscale.job.models import JobState -from anyscale.compute_config.models import ( - ComputeConfig, HeadNodeConfig, MarketType, WorkerNodeGroupConfig -) -from anyscale.job.models import JobConfig -from anyscale.service.models import ServiceConfig, RayGCSExternalStorageConfig, ServiceState # Airflow imports from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.context import Context -from airflow.triggers.base import TriggerEvent -from airflow.utils.decorators import apply_defaults +from anyscale.compute_config.models import ComputeConfig +from anyscale.job.models import JobConfig, JobState +from anyscale.service.models import RayGCSExternalStorageConfig, ServiceConfig, ServiceState + from anyscale_provider.hooks.anyscale import AnyscaleHook from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger + class SubmitAnyscaleJob(BaseOperator): """ Submits a job to Anyscale from Apache Airflow. @@ -49,21 +44,24 @@ class SubmitAnyscaleJob(BaseOperator): :raises AirflowException: If job name or entrypoint is not provided. """ - - def __init__(self, - conn_id: str, - name: str, - image_uri: str, - compute_config: Union[ComputeConfig, Dict[str, Any], str], - working_dir: str, - entrypoint: str, - excludes: Optional[List[str]] = None, - requirements: Optional[Union[str, List[str]]] = None, - env_vars: Optional[Dict[str, str]] = None, - py_modules: Optional[List[str]] = None, - max_retries: int = 1, - *args: Any, **kwargs: Any) -> None: - super(SubmitAnyscaleJob, self).__init__(*args, **kwargs) + + def __init__( + self, + conn_id: str, + name: str, + image_uri: str, + compute_config: Union[ComputeConfig, Dict[str, Any], str], + working_dir: str, + entrypoint: str, + excludes: Optional[List[str]] = None, + requirements: Optional[Union[str, List[str]]] = None, + env_vars: Optional[Dict[str, str]] = None, + py_modules: Optional[List[str]] = None, + max_retries: int = 1, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) self.conn_id = conn_id self.name = name self.image_uri = image_uri @@ -88,29 +86,29 @@ def __init__(self, "env_vars": env_vars, "py_modules": py_modules, "entrypoint": entrypoint, - "max_retries": max_retries + "max_retries": max_retries, } if not self.name: raise AirflowException("Job name is required.") - + # Ensure entrypoint is not empty if not self.entrypoint: raise AirflowException("Entrypoint must be specified.") - + def on_kill(self) -> None: if self.job_id is not None: self.hook.terminate_job(self.job_id, 5) self.log.info("Termination request received. Submitted request to terminate the anyscale job.") return - + @cached_property def hook(self) -> AnyscaleHook: """Return an instance of the AnyscaleHook.""" return AnyscaleHook(conn_id=self.conn_id) def execute(self, context: Context) -> Optional[str]: - + if not self.hook: self.log.info("SDK is not available.") raise AirflowException("SDK is not available.") @@ -127,9 +125,9 @@ def execute(self, context: Context) -> Optional[str]: self.log.info(f"Current status for {self.job_id} is: {current_status}") self.process_job_status(self.job_id, current_status) - + return self.job_id - + def process_job_status(self, job_id: str, current_status: str) -> None: if current_status in (JobState.STARTING, JobState.RUNNING): self.defer_job_polling(job_id) @@ -139,21 +137,22 @@ def process_job_status(self, job_id: str, current_status: str) -> None: raise AirflowException(f"Job {job_id} failed.") else: raise Exception(f"Unexpected state `{current_status}` for job_id `{job_id}`.") - + def defer_job_polling(self, job_id: str) -> None: self.log.info("Deferring the polling to AnyscaleJobTrigger...") - self.defer(trigger=AnyscaleJobTrigger(conn_id=self.conn_id, - job_id=job_id, - job_start_time=self.created_at, - poll_interval=60), - method_name="execute_complete") + self.defer( + trigger=AnyscaleJobTrigger( + conn_id=self.conn_id, job_id=job_id, job_start_time=self.created_at, poll_interval=60 + ), + method_name="execute_complete", + ) def get_current_status(self, job_id: str) -> str: return str(self.hook.get_job_status(job_id=job_id).state) def execute_complete(self, context: Context, event: Any) -> None: current_job_id = event["job_id"] - + if event["status"] == JobState.FAILED: self.log.info(f"Anyscale job {current_job_id} ended with status: {event['status']}") raise AirflowException(f"Job {current_job_id} failed with error {event['message']}") @@ -161,12 +160,13 @@ def execute_complete(self, context: Context, event: Any) -> None: self.log.info(f"Anyscale job {current_job_id} completed with status: {event['status']}") return None + class RolloutAnyscaleService(BaseOperator): """ Rolls out a service on Anyscale from Apache Airflow. - This operator handles the deployment of services on Anyscale, including the necessary - configurations and options. It ensures the service is rolled out according to the + This operator handles the deployment of services on Anyscale, including the necessary + configurations and options. It ensures the service is rolled out according to the specified parameters and handles the deployment lifecycle. .. seealso:: @@ -197,47 +197,49 @@ class RolloutAnyscaleService(BaseOperator): :raises AirflowException: If the SDK is not available or the service deployment fails. """ - def __init__(self, - conn_id: str, - name: str, - image_uri: str, - compute_config: Union[ComputeConfig, Dict[str, Any], str], - applications: List[Dict[str, Any]], - working_dir: str, - containerfile: Optional[str] = None, - excludes: Optional[List[str]] = None, - requirements: Optional[Union[str, List[str]]] = None, - env_vars: Optional[Dict[str, str]] = None, - py_modules: Optional[List[str]] = None, - query_auth_token_enabled: bool = False, - http_options: Optional[Dict[str, Any]] = None, - grpc_options: Optional[Dict[str, Any]] = None, - logging_config: Optional[Dict[str, Any]] = None, - ray_gcs_external_storage_config: Optional[Union[RayGCSExternalStorageConfig, Dict[str, Any]]] = None, - in_place: bool = False, - canary_percent: Optional[float] = None, - max_surge_percent: Optional[float] = None, - **kwargs: Any) -> None: + def __init__( + self, + conn_id: str, + name: str, + image_uri: str, + compute_config: Union[ComputeConfig, Dict[str, Any], str], + applications: List[Dict[str, Any]], + working_dir: str, + containerfile: Optional[str] = None, + excludes: Optional[List[str]] = None, + requirements: Optional[Union[str, List[str]]] = None, + env_vars: Optional[Dict[str, str]] = None, + py_modules: Optional[List[str]] = None, + query_auth_token_enabled: bool = False, + http_options: Optional[Dict[str, Any]] = None, + grpc_options: Optional[Dict[str, Any]] = None, + logging_config: Optional[Dict[str, Any]] = None, + ray_gcs_external_storage_config: Optional[Union[RayGCSExternalStorageConfig, Dict[str, Any]]] = None, + in_place: bool = False, + canary_percent: Optional[float] = None, + max_surge_percent: Optional[float] = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.conn_id = conn_id # Set up explicit parameters self.service_params: Dict[str, Any] = { - 'name': name, - 'image_uri': image_uri, - 'containerfile': containerfile, - 'compute_config': compute_config, - 'working_dir': working_dir, - 'excludes': excludes, - 'requirements': requirements, - 'env_vars': env_vars, - 'py_modules': py_modules, - 'applications': applications, - 'query_auth_token_enabled': query_auth_token_enabled, - 'http_options': http_options, - 'grpc_options': grpc_options, - 'logging_config': logging_config, - 'ray_gcs_external_storage_config': ray_gcs_external_storage_config + "name": name, + "image_uri": image_uri, + "containerfile": containerfile, + "compute_config": compute_config, + "working_dir": working_dir, + "excludes": excludes, + "requirements": requirements, + "env_vars": env_vars, + "py_modules": py_modules, + "applications": applications, + "query_auth_token_enabled": query_auth_token_enabled, + "http_options": http_options, + "grpc_options": grpc_options, + "logging_config": logging_config, + "ray_gcs_external_storage_config": ray_gcs_external_storage_config, } self.in_place = in_place @@ -245,47 +247,53 @@ def __init__(self, self.max_surge_percent = max_surge_percent # Ensure name is not empty - if not self.service_params['name']: + if not self.service_params["name"]: raise ValueError("Service name is required.") - + # Ensure at least one application is specified - if not self.service_params['applications']: + if not self.service_params["applications"]: raise ValueError("At least one application must be specified.") @cached_property def hook(self) -> AnyscaleHook: """Return an instance of the AnyscaleHook.""" return AnyscaleHook(conn_id=self.conn_id) - + def execute(self, context: Context) -> Optional[str]: if not self.hook: self.log.info(f"SDK is not available...") raise AirflowException("SDK is not available") - + svc_config = ServiceConfig(**self.service_params) - self.log.info("Service with config object: {}".format(svc_config)) - + self.log.info(f"Service with config object: {svc_config}") + # Call the SDK method with the dynamically created service model - service_id = self.hook.deploy_service(config=svc_config, - in_place=self.in_place, - canary_percent=self.canary_percent, - max_surge_percent=self.max_surge_percent) - - self.defer(trigger=AnyscaleServiceTrigger(conn_id=self.conn_id, - service_name=self.service_params['name'], - expected_state=ServiceState.RUNNING, - canary_percent=self.canary_percent, - poll_interval=60, - timeout=600), - method_name="execute_complete") - - self.log.info(f"Service rollout id: {service_id}") + service_id = self.hook.deploy_service( + config=svc_config, + in_place=self.in_place, + canary_percent=self.canary_percent, + max_surge_percent=self.max_surge_percent, + ) + + self.defer( + trigger=AnyscaleServiceTrigger( + conn_id=self.conn_id, + service_name=self.service_params["name"], + expected_state=ServiceState.RUNNING, + canary_percent=self.canary_percent, + poll_interval=60, + timeout=600, + ), + method_name="execute_complete", + ) + + self.log.info(f"Service rollout id: {service_id}") return service_id - + def execute_complete(self, context: Context, event: Any) -> None: self.log.info(f"Execution completed...") service_id = event["service_name"] - + if event["status"] == ServiceState.SYSTEM_FAILURE: self.log.info(f"Anyscale service deployment {service_id} ended with status: {event['status']}") raise AirflowException(f"Job {service_id} failed with error {event['message']}") @@ -293,5 +301,3 @@ def execute_complete(self, context: Context, event: Any) -> None: self.log.info(f"Anyscale service deployment {service_id} completed with status: {event['status']}") return None - - diff --git a/anyscale_provider/triggers/anyscale.py b/anyscale_provider/triggers/anyscale.py index 7d04844..0e820bc 100644 --- a/anyscale_provider/triggers/anyscale.py +++ b/anyscale_provider/triggers/anyscale.py @@ -1,24 +1,22 @@ -import time -import logging import asyncio +import time from functools import partial -from datetime import datetime, timedelta -from typing import Any, Dict, AsyncIterator, Tuple, Optional +from typing import Any, AsyncIterator, Dict, Optional, Tuple +from airflow.compat.functools import cached_property +from airflow.triggers.base import BaseTrigger, TriggerEvent from anyscale.job.models import JobState from anyscale.service.models import ServiceState -from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.compat.functools import cached_property - from anyscale_provider.hooks.anyscale import AnyscaleHook + class AnyscaleJobTrigger(BaseTrigger): """ Triggers and monitors the status of a job submitted to Anyscale. - This trigger periodically checks the status of a submitted job on Anyscale and - yields events based on the job's status. It handles timeouts and errors during + This trigger periodically checks the status of a submitted job on Anyscale and + yields events based on the job's status. It handles timeouts and errors during the polling process. .. seealso:: @@ -35,68 +33,76 @@ class AnyscaleJobTrigger(BaseTrigger): """ def __init__(self, conn_id: str, job_id: str, job_start_time: float, poll_interval: int = 60, timeout: int = 3600): - super().__init__() # type: ignore[no-untyped-call] + super().__init__() # type: ignore[no-untyped-call] self.conn_id = conn_id self.job_id = job_id self.job_start_time = job_start_time self.poll_interval = poll_interval self.timeout = timeout self.end_time = time.time() + self.timeout - + @cached_property def hook(self) -> AnyscaleHook: """Return an instance of the AnyscaleHook.""" return AnyscaleHook(conn_id=self.conn_id) def serialize(self) -> Tuple[str, Dict[str, Any]]: - return ("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger", { - "conn_id": self.conn_id, - "job_id": self.job_id, - "job_start_time": self.job_start_time, - "poll_interval": self.poll_interval, - "timeout": self.timeout - }) + return ( + "anyscale_provider.triggers.anyscale.AnyscaleJobTrigger", + { + "conn_id": self.conn_id, + "job_id": self.job_id, + "job_start_time": self.job_start_time, + "poll_interval": self.poll_interval, + "timeout": self.timeout, + }, + ) async def run(self) -> AsyncIterator[TriggerEvent]: if not self.job_id: self.log.info("No job_id provided") - yield TriggerEvent({"status": "error", "message": "No job_id provided to async trigger", "job_id": self.job_id}) + yield TriggerEvent( + {"status": "error", "message": "No job_id provided to async trigger", "job_id": self.job_id} + ) try: while not self.is_terminal_status(self.job_id): if time.time() > self.end_time: - yield TriggerEvent({ - "status": "timeout", - "message": f"Timeout waiting for job {self.job_id} to complete.", - "job_id": self.job_id - }) + yield TriggerEvent( + { + "status": "timeout", + "message": f"Timeout waiting for job {self.job_id} to complete.", + "job_id": self.job_id, + } + ) return await asyncio.sleep(self.poll_interval) - + # Fetch and print logs loop = asyncio.get_running_loop() - logs = await loop.run_in_executor( - None, - partial(self.hook.get_logs, job_id=self.job_id) - ) + logs = await loop.run_in_executor(None, partial(self.hook.get_logs, job_id=self.job_id)) for log in logs.split("\n"): self.log.info(log) # Once out of the loop, the job has reached a terminal status job_status = self.get_current_status(self.job_id) self.log.info(f"Current status of the job is {job_status}") - - yield TriggerEvent({ - "status": job_status, - "message": f"Job {self.job_id} completed with status {job_status}.", - "job_id": self.job_id - }) + + yield TriggerEvent( + { + "status": job_status, + "message": f"Job {self.job_id} completed with status {job_status}.", + "job_id": self.job_id, + } + ) except Exception: self.log.exception("An error occurred while polling for job status.") - yield TriggerEvent({ - "status": JobState.FAILED, - "message": "An error occurred while polling for job status.", - "job_id": self.job_id - }) + yield TriggerEvent( + { + "status": JobState.FAILED, + "message": "An error occurred while polling for job status.", + "job_id": self.job_id, + } + ) def get_current_status(self, job_id: str) -> str: job_status = self.hook.get_job_status(job_id=job_id).state @@ -113,7 +119,7 @@ class AnyscaleServiceTrigger(BaseTrigger): Triggers and monitors the status of a service deployment on Anyscale. This trigger periodically checks the status of a service deployment on Anyscale - and yields events based on the service's status. It handles timeouts and errors + and yields events based on the service's status. It handles timeouts and errors during the monitoring process. .. seealso:: @@ -129,14 +135,16 @@ class AnyscaleServiceTrigger(BaseTrigger): :raises AirflowException: If no service_name is provided or an error occurs during monitoring. """ - def __init__(self, - conn_id: str, - service_name: str, - expected_state: str, - canary_percent: Optional[float], - poll_interval: int = 60, - timeout: int = 600): - super().__init__() # type: ignore[no-untyped-call] + def __init__( + self, + conn_id: str, + service_name: str, + expected_state: str, + canary_percent: Optional[float], + poll_interval: int = 60, + timeout: int = 600, + ): + super().__init__() # type: ignore[no-untyped-call] self.conn_id = conn_id self.service_name = service_name self.expected_state = expected_state @@ -151,53 +159,74 @@ def hook(self) -> AnyscaleHook: return AnyscaleHook(conn_id=self.conn_id) def serialize(self) -> Tuple[str, Dict[str, Any]]: - return ("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger", { - "conn_id": self.conn_id, - "service_name": self.service_name, - "expected_state": self.expected_state, - "canary_percent": self.canary_percent, - "poll_interval": self.poll_interval, - "timeout": self.timeout - }) + return ( + "anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger", + { + "conn_id": self.conn_id, + "service_name": self.service_name, + "expected_state": self.expected_state, + "canary_percent": self.canary_percent, + "poll_interval": self.poll_interval, + "timeout": self.timeout, + }, + ) async def run(self) -> AsyncIterator[TriggerEvent]: if not self.service_name: self.log.info("No service_name provided") - yield TriggerEvent({"status": ServiceState.SYSTEM_FAILURE, "message": "No service_name provided to async trigger", "service_name": self.service_name}) + yield TriggerEvent( + { + "status": ServiceState.SYSTEM_FAILURE, + "message": "No service_name provided to async trigger", + "service_name": self.service_name, + } + ) try: - self.log.info(f"Monitoring service {self.service_name} every {self.poll_interval} seconds to reach {self.expected_state}") + self.log.info( + f"Monitoring service {self.service_name} every {self.poll_interval} seconds to reach {self.expected_state}" + ) while self.check_current_status(self.service_name): if time.time() > self.end_time: - yield TriggerEvent({ - "status": ServiceState.UNKNOWN, - "message": f"Service {self.service_name} did not reach {self.expected_state} within the timeout period.", - "service_name": self.service_name - }) + yield TriggerEvent( + { + "status": ServiceState.UNKNOWN, + "message": f"Service {self.service_name} did not reach {self.expected_state} within the timeout period.", + "service_name": self.service_name, + } + ) return - + await asyncio.sleep(self.poll_interval) current_state = self.get_current_status(self.service_name) if current_state == ServiceState.RUNNING: - yield TriggerEvent({"status": ServiceState.RUNNING, - "message": "Service deployment succeeded", - "service_name": self.service_name}) + yield TriggerEvent( + { + "status": ServiceState.RUNNING, + "message": "Service deployment succeeded", + "service_name": self.service_name, + } + ) return elif self.expected_state != current_state and not self.check_current_status(self.service_name): - yield TriggerEvent({ - "status": ServiceState.SYSTEM_FAILURE, - "message": f"Service {self.service_name} entered an unexpected state: {current_state}", - "service_name": self.service_name - }) + yield TriggerEvent( + { + "status": ServiceState.SYSTEM_FAILURE, + "message": f"Service {self.service_name} entered an unexpected state: {current_state}", + "service_name": self.service_name, + } + ) return except Exception as e: self.log.error("An error occurred during monitoring:", exc_info=True) - yield TriggerEvent({"status": ServiceState.SYSTEM_FAILURE, "message": str(e), "service_name": self.service_name}) - + yield TriggerEvent( + {"status": ServiceState.SYSTEM_FAILURE, "message": str(e), "service_name": self.service_name} + ) + def get_current_status(self, service_name: str) -> str: service_status = self.hook.get_service_status(service_name) @@ -208,8 +237,13 @@ def get_current_status(self, service_name: str) -> str: return str(service_status.canary_version.state) else: return str(service_status.state) - + def check_current_status(self, service_name: str) -> bool: job_status = self.get_current_status(service_name) self.log.info(f"Current job status for {service_name} is: {job_status}") - return job_status in (ServiceState.STARTING, ServiceState.UPDATING, ServiceState.ROLLING_OUT, ServiceState.UNHEALTHY) \ No newline at end of file + return job_status in ( + ServiceState.STARTING, + ServiceState.UPDATING, + ServiceState.ROLLING_OUT, + ServiceState.UNHEALTHY, + ) diff --git a/codespell-ignore-words.txt b/codespell-ignore-words.txt new file mode 100644 index 0000000..bf52b4c --- /dev/null +++ b/codespell-ignore-words.txt @@ -0,0 +1 @@ +assertIn diff --git a/pyproject.toml b/pyproject.toml index 7c8a4c5..66ec6e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ freeze = "pip freeze" test = 'sh scripts/test/unit_test.sh' test-cov = 'sh scripts/test/unit_cov.sh' test-integration = 'sh scripts/test/integration_test.sh' -type-check = " pre-commit run mypy --files anyscale_provider/**/*" +static-check = "pre-commit run --all-files" ###################################### # THIRD PARTY TOOLS @@ -103,4 +103,4 @@ ignore = ["F541"] max-complexity = 10 [tool.distutils.bdist_wheel] -universal = true \ No newline at end of file +universal = true diff --git a/scripts/test/integration_test.sh b/scripts/test/integration_test.sh index c8b93da..7dfeb5a 100644 --- a/scripts/test/integration_test.sh +++ b/scripts/test/integration_test.sh @@ -5,4 +5,4 @@ pytest -vv \ --durations=0 \ -m integration \ -s \ - --log-cli-level=DEBUG \ No newline at end of file + --log-cli-level=DEBUG diff --git a/scripts/test/pre-install-airflow.sh b/scripts/test/pre-install-airflow.sh index e9aeb9f..de29703 100644 --- a/scripts/test/pre-install-airflow.sh +++ b/scripts/test/pre-install-airflow.sh @@ -11,4 +11,4 @@ mv /tmp/constraint.txt.tmp /tmp/constraint.txt # Install Airflow with constraints pip install apache-airflow==$AIRFLOW_VERSION --constraint /tmp/constraint.txt pip install pydantic --constraint /tmp/constraint.txt -rm /tmp/constraint.txt \ No newline at end of file +rm /tmp/constraint.txt diff --git a/scripts/test/unit_cov.sh b/scripts/test/unit_cov.sh index f6426fd..1c8f335 100644 --- a/scripts/test/unit_cov.sh +++ b/scripts/test/unit_cov.sh @@ -5,4 +5,4 @@ pytest \ --cov-report=xml \ --durations=0 \ -m "not (integration or perf)" \ - --ignore=tests/dags/test_dag_example.py \ No newline at end of file + --ignore=tests/dags/test_dag_example.py diff --git a/scripts/test/unit_test.sh b/scripts/test/unit_test.sh index 2fd9dc4..4b6dec8 100644 --- a/scripts/test/unit_test.sh +++ b/scripts/test/unit_test.sh @@ -1,4 +1,4 @@ pytest \ -vv \ --durations=0 \ - -m "not (integration or perf)" \ No newline at end of file + -m "not (integration or perf)" diff --git a/tests/dags/example_dags/anyscale_dag.py b/tests/dags/example_dags/anyscale_dag.py index 91ff527..eae3d65 100644 --- a/tests/dags/example_dags/anyscale_dag.py +++ b/tests/dags/example_dags/anyscale_dag.py @@ -1,42 +1,44 @@ from datetime import datetime, timedelta -from airflow import DAG from pathlib import Path + +from airflow import DAG + from anyscale_provider.operators.anyscale import SubmitAnyscaleJob default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'start_date': datetime(2024, 4, 2), - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5), + "owner": "airflow", + "depends_on_past": False, + "start_date": datetime(2024, 4, 2), + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), } # Define the Anyscale connection ANYSCALE_CONN_ID = "anyscale_conn" # Constants -FOLDER_PATH = Path(__file__).parent /"ray_scripts" +FOLDER_PATH = Path(__file__).parent / "ray_scripts" dag = DAG( - 'sample_anyscale_workflow', + "sample_anyscale_workflow", default_args=default_args, - description='A DAG to interact with Anyscale triggered manually', + description="A DAG to interact with Anyscale triggered manually", schedule_interval=None, # This DAG is not scheduled, only triggered manually catchup=False, ) submit_anyscale_job = SubmitAnyscaleJob( - task_id='submit_anyscale_job', - conn_id = ANYSCALE_CONN_ID, - name = 'AstroJob', - image_uri = 'anyscale/ray:2.23.0-py311', - compute_config = 'my-compute-config:1', - working_dir = str(FOLDER_PATH), - entrypoint= 'python script.py', - requirements = ["requests","pandas","numpy","torch"], - max_retries = 1, + task_id="submit_anyscale_job", + conn_id=ANYSCALE_CONN_ID, + name="AstroJob", + image_uri="anyscale/ray:2.23.0-py311", + compute_config="my-compute-config:1", + working_dir=str(FOLDER_PATH), + entrypoint="python script.py", + requirements=["requests", "pandas", "numpy", "torch"], + max_retries=1, dag=dag, ) diff --git a/tests/dags/example_dags/anyscale_service.py b/tests/dags/example_dags/anyscale_service.py index 049daa6..8fa1d79 100644 --- a/tests/dags/example_dags/anyscale_service.py +++ b/tests/dags/example_dags/anyscale_service.py @@ -1,26 +1,28 @@ from datetime import datetime, timedelta + from airflow import DAG from airflow.operators.python import PythonOperator -from anyscale_provider.operators.anyscale import RolloutAnyscaleService + from anyscale_provider.hooks.anyscale import AnyscaleHook +from anyscale_provider.operators.anyscale import RolloutAnyscaleService default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'start_date': datetime(2024, 4, 2), - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5), + "owner": "airflow", + "depends_on_past": False, + "start_date": datetime(2024, 4, 2), + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), } # Define the Anyscale connection ANYSCALE_CONN_ID = "anyscale_conn" dag = DAG( - 'sample_anyscale_service_workflow', + "sample_anyscale_service_workflow", default_args=default_args, - description='A DAG to interact with Anyscale triggered manually', + description="A DAG to interact with Anyscale triggered manually", schedule_interval=None, # This DAG is not scheduled, only triggered manually catchup=False, ) @@ -29,27 +31,28 @@ task_id="rollout_anyscale_service", conn_id=ANYSCALE_CONN_ID, name="AstroService", - image_uri='anyscale/ray:2.23.0-py311', - compute_config='my-compute-config:1', + image_uri="anyscale/ray:2.23.0-py311", + compute_config="my-compute-config:1", working_dir="https://github.com/anyscale/docs_examples/archive/refs/heads/main.zip", applications=[{"import_path": "sentiment_analysis.app:model"}], requirements=["transformers", "requests", "pandas", "numpy", "torch"], in_place=False, canary_percent=None, - dag=dag + dag=dag, ) + def terminate_service(): hook = AnyscaleHook(conn_id=ANYSCALE_CONN_ID) - result = hook.terminate_service(service_id="AstroService", - time_delay=5) + result = hook.terminate_service(service_id="AstroService", time_delay=5) print(result) + terminate_anyscale_service = PythonOperator( - task_id='initialize_anyscale_hook', + task_id="initialize_anyscale_hook", python_callable=terminate_service, dag=dag, ) # Defining the task sequence -deploy_anyscale_service >> terminate_anyscale_service \ No newline at end of file +deploy_anyscale_service >> terminate_anyscale_service diff --git a/tests/dags/example_dags/ray_scripts/ray-serve.py b/tests/dags/example_dags/ray_scripts/ray-serve.py index f3efcdc..8738578 100644 --- a/tests/dags/example_dags/ray_scripts/ray-serve.py +++ b/tests/dags/example_dags/ray_scripts/ray-serve.py @@ -1,8 +1,7 @@ # Filename: local_dev.py -from starlette.requests import Request - from ray import serve -from ray.serve.handle import DeploymentHandle, DeploymentResponse +from ray.serve.handle import DeploymentHandle +from starlette.requests import Request @serve.deployment @@ -23,4 +22,4 @@ async def __call__(self, request: Request): return await self.say_hello_twice(request.query_params["name"]) -app = HelloDeployment.bind(Doubler.bind()) \ No newline at end of file +app = HelloDeployment.bind(Doubler.bind()) diff --git a/tests/dags/example_dags/ray_scripts/script-gpu.py b/tests/dags/example_dags/ray_scripts/script-gpu.py index b37f8bc..db5549f 100644 --- a/tests/dags/example_dags/ray_scripts/script-gpu.py +++ b/tests/dags/example_dags/ray_scripts/script-gpu.py @@ -1,8 +1,10 @@ +import time + import ray import torch -import time -ray.init(address='auto') +ray.init(address="auto") + @ray.remote(num_gpus=1) def gpu_task(): @@ -11,7 +13,7 @@ def gpu_task(): # Create a random tensor and move it to GPU data = torch.randn([1000, 1000]).cuda() # Perform a simple computation (matrix multiplication) on the GPU - result = torch.matmul(data, data.t()) + torch.matmul(data, data.t()) # Simulate some processing time time.sleep(1) # Print a success message @@ -19,6 +21,7 @@ def gpu_task(): else: print("CUDA is not available. This task did not run on a GPU.") + # Running the GPU task gpu_future = gpu_task.remote() diff --git a/tests/dags/example_dags/ray_scripts/script.py b/tests/dags/example_dags/ray_scripts/script.py index e72be1b..ff9d4c0 100644 --- a/tests/dags/example_dags/ray_scripts/script.py +++ b/tests/dags/example_dags/ray_scripts/script.py @@ -1,9 +1,11 @@ # script.py import ray + @ray.remote def hello_world(): return "hello world" + ray.init("auto") -print(ray.get(hello_world.remote())) \ No newline at end of file +print(ray.get(hello_world.remote())) diff --git a/tests/dags/test_dag_example.py b/tests/dags/test_dag_example.py index 3b24d74..8236e84 100644 --- a/tests/dags/test_dag_example.py +++ b/tests/dags/test_dag_example.py @@ -1,18 +1,20 @@ import os -import logging -from contextlib import contextmanager -import pytest from pathlib import Path -from airflow.models import DagBag, Connection + +import pytest +from airflow.models import Connection, DagBag from airflow.utils.db import create_default_connections -from airflow.utils.session import provide_session, create_session +from airflow.utils.session import create_session # Correctly construct the example DAGs directory path EXAMPLE_DAGS_DIR = Path(__file__).parent / "example_dags" print(f"EXAMPLE_DAGS_DIR: {EXAMPLE_DAGS_DIR}") + def get_dags(dag_folder=None): - dag_bag = DagBag(dag_folder=str(dag_folder), include_examples=False) if dag_folder else DagBag(include_examples=False) + dag_bag = ( + DagBag(dag_folder=str(dag_folder), include_examples=False) if dag_folder else DagBag(include_examples=False) + ) def strip_path_prefix(path): return os.path.relpath(path, os.environ.get("AIRFLOW_HOME", "")) @@ -24,23 +26,26 @@ def strip_path_prefix(path): return dags_info + @pytest.fixture(scope="module") def setup_airflow_db(): - os.system('airflow db init') + os.system("airflow db init") # Explicitly create the tables if necessary create_default_connections() with create_session() as session: conn = Connection( conn_id="anyscale_conn", conn_type="anyscale", - extra=f'{{"ANYSCALE_CLI_TOKEN": "{os.environ.get("ANYSCALE_CLI_TOKEN", "")}"}}' + extra=f'{{"ANYSCALE_CLI_TOKEN": "{os.environ.get("ANYSCALE_CLI_TOKEN", "")}"}}', ) session.add(conn) session.commit() + dags = get_dags(EXAMPLE_DAGS_DIR) print(f"Discovered DAGs: {dags}") + @pytest.mark.integration @pytest.mark.parametrize("dag_id,dag,fileloc", dags, ids=[x[2] for x in dags]) def test_dag_runs(setup_airflow_db, dag_id, dag, fileloc): @@ -50,4 +55,4 @@ def test_dag_runs(setup_airflow_db, dag_id, dag, fileloc): dag.test() except Exception as e: print(f"Error running DAG {dag_id}: {e}") - raise e \ No newline at end of file + raise e diff --git a/tests/hooks/test_anyscale_hook.py b/tests/hooks/test_anyscale_hook.py index da01251..f4daf3f 100644 --- a/tests/hooks/test_anyscale_hook.py +++ b/tests/hooks/test_anyscale_hook.py @@ -1,85 +1,80 @@ import json -import pytest from unittest import mock -from unittest.mock import patch, MagicMock +from unittest.mock import patch + +import pytest from airflow.exceptions import AirflowException from airflow.models import Connection -from anyscale.job.models import JobConfig, JobStatus, JobState, JobRunStatus -from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceState +from anyscale.job.models import JobConfig, JobRunStatus, JobState, JobStatus +from anyscale.service.models import ServiceConfig, ServiceState, ServiceStatus + from anyscale_provider.hooks.anyscale import AnyscaleHook API_KEY = "api_key_value" + class TestAnyscaleHook: def setup_method(self): with mock.patch("anyscale_provider.hooks.anyscale.Anyscale"): with mock.patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection") as m: m.return_value = Connection( - conn_id='anyscale_default', - conn_type='http', - host='localhost', + conn_id="anyscale_default", + conn_type="http", + host="localhost", password=API_KEY, - extra=json.dumps({}) + extra=json.dumps({}), ) self.hook = AnyscaleHook() - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection') + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection") @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_api_key_required(self, mock_anyscale, mock_get_connection): mock_get_connection.return_value = Connection( - conn_id='anyscale_default', - conn_type='http', - host='localhost', - password=None, - extra=json.dumps({}) + conn_id="anyscale_default", conn_type="http", host="localhost", password=None, extra=json.dumps({}) ) with pytest.raises(AirflowException) as ctx: AnyscaleHook() assert str(ctx.value) == "Missing API token for connection id anyscale_default" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection') + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection") @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_successful_initialization(self, mock_anyscale, mock_get_connection): mock_get_connection.return_value = Connection( - conn_id='anyscale_default', - conn_type='http', - host='localhost', - password=API_KEY, - extra=json.dumps({}) + conn_id="anyscale_default", conn_type="http", host="localhost", password=API_KEY, extra=json.dumps({}) ) hook = AnyscaleHook() - assert hook.get_connection('anyscale_default').password == API_KEY - - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection') + assert hook.get_connection("anyscale_default").password == API_KEY + + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection") @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_init_with_env_token(self, mock_anyscale, mock_get_connection): with mock.patch.dict("os.environ", {"ANYSCALE_CLI_TOKEN": API_KEY}): mock_get_connection.return_value = Connection( - conn_id='anyscale_default', - conn_type='http', - host='localhost', + conn_id="anyscale_default", + conn_type="http", + host="localhost", password=None, # No password in connection - extra=json.dumps({}) + extra=json.dumps({}), ) # Mock the Anyscale class to return an instance with the expected auth_token mock_instance = mock_anyscale.return_value mock_instance.auth_token = API_KEY - + hook = AnyscaleHook() assert hook.sdk.auth_token == API_KEY @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_submit_job(self, mock_anyscale): job_config = JobConfig(name="test_job", entrypoint="python script.py") - + # Create a mock SDK instance with a mock job submit method mock_sdk_instance = mock_anyscale.return_value mock_sdk_instance.job.submit.return_value = "test_job_id" - + # Patch the instance's sdk attribute directly self.hook.sdk = mock_sdk_instance - + result = self.hook.submit_job(job_config) mock_sdk_instance.job.submit.assert_called_once_with(config=job_config) @@ -91,7 +86,7 @@ def test_submit_job_error(self, mock_anyscale): mock_sdk_instance = mock_anyscale.return_value mock_sdk_instance.job.submit.side_effect = AirflowException("Submit job failed") - + self.hook.sdk = mock_sdk_instance with pytest.raises(AirflowException) as exc: @@ -102,25 +97,28 @@ def test_submit_job_error(self, mock_anyscale): @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_deploy_service(self, mock_anyscale): - service_config = ServiceConfig(name="test_service", applications=[{"name": "app1", "import_path": "module.optional_submodule:app"}]) - + service_config = ServiceConfig( + name="test_service", applications=[{"name": "app1", "import_path": "module.optional_submodule:app"}] + ) + # Create a mock SDK instance with a mock service deploy method mock_sdk_instance = mock_anyscale.return_value mock_sdk_instance.service.deploy.return_value = "test_service_id" self.hook.sdk = mock_sdk_instance - - result = self.hook.deploy_service(service_config, - in_place=False, - canary_percent=10, - max_surge_percent=20) - - mock_sdk_instance.service.deploy.assert_called_once_with(config=service_config, in_place=False, canary_percent=10, max_surge_percent=20) + + result = self.hook.deploy_service(service_config, in_place=False, canary_percent=10, max_surge_percent=20) + + mock_sdk_instance.service.deploy.assert_called_once_with( + config=service_config, in_place=False, canary_percent=10, max_surge_percent=20 + ) assert result == "test_service_id" - @patch('anyscale_provider.hooks.anyscale.Anyscale') + @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_deploy_service_error(self, mock_anyscale): - service_config = ServiceConfig(name="test_service", applications=[{"name": "app1", "import_path": "module.optional_submodule:app"}]) - + service_config = ServiceConfig( + name="test_service", applications=[{"name": "app1", "import_path": "module.optional_submodule:app"}] + ) + mock_sdk_instance = mock_anyscale.return_value mock_sdk_instance.service.deploy.side_effect = AirflowException("Deploy service failed") self.hook.sdk = mock_sdk_instance @@ -128,7 +126,9 @@ def test_deploy_service_error(self, mock_anyscale): with pytest.raises(AirflowException) as exc: self.hook.deploy_service(service_config, in_place=False, canary_percent=10, max_surge_percent=20) - mock_sdk_instance.service.deploy.assert_called_once_with(config=service_config, in_place=False, canary_percent=10, max_surge_percent=20) + mock_sdk_instance.service.deploy.assert_called_once_with( + config=service_config, in_place=False, canary_percent=10, max_surge_percent=20 + ) assert str(exc.value) == "Deploy service failed" @patch("anyscale_provider.hooks.anyscale.Anyscale") @@ -142,7 +142,7 @@ def test_get_job_status(self, mock_anyscale): name="test_job", config=job_config, state=JobState.SUCCEEDED, - runs=[JobRunStatus(name="test", state=JobState.SUCCEEDED)] + runs=[JobRunStatus(name="test", state=JobState.SUCCEEDED)], ) # Patch the instance's sdk attribute directly @@ -213,8 +213,8 @@ def test_terminate_service_error(self, mock_anyscale): self.hook.terminate_service("test_service_id", time_delay=1) mock_sdk_instance.service.terminate.assert_called_once_with(name="test_service_id") assert str(exc.value) == "Service termination failed with error: Terminate service failed" - - @patch('anyscale_provider.hooks.anyscale.Anyscale') + + @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_get_logs(self, mock_anyscale): mock_sdk_instance = mock_anyscale.return_value @@ -223,10 +223,10 @@ def test_get_logs(self, mock_anyscale): result = self.hook.get_logs("test_job_id") - mock_sdk_instance.job.get_logs.assert_called_once_with(job_id = "test_job_id") + mock_sdk_instance.job.get_logs.assert_called_once_with(job_id="test_job_id") assert result == "job logs" - @patch('anyscale_provider.hooks.anyscale.Anyscale') + @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_get_logs_empty(self, mock_anyscale): mock_sdk_instance = mock_anyscale.return_value mock_sdk_instance.job.get_logs.return_value = "" @@ -234,12 +234,14 @@ def test_get_logs_empty(self, mock_anyscale): result = self.hook.get_logs("test_job_id") - mock_sdk_instance.job.get_logs.assert_called_once_with(job_id = "test_job_id") + mock_sdk_instance.job.get_logs.assert_called_once_with(job_id="test_job_id") assert result == "" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status') + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status") def test_get_service_status(self, mock_get_service_status): - mock_service_status = ServiceStatus(id="test_service_id", name="test_service", query_url="http://example.com", state=ServiceState.RUNNING) + mock_service_status = ServiceStatus( + id="test_service_id", name="test_service", query_url="http://example.com", state=ServiceState.RUNNING + ) mock_get_service_status.return_value = mock_service_status result = self.hook.get_service_status("test_service_name") @@ -250,7 +252,7 @@ def test_get_service_status(self, mock_get_service_status): assert result.query_url == "http://example.com" assert result.state == ServiceState.RUNNING - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status') + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status") def test_get_service_status_error(self, mock_get_service_status): mock_get_service_status.side_effect = AirflowException("Get service status failed") @@ -261,7 +263,7 @@ def test_get_service_status_error(self, mock_get_service_status): @patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None) def test_terminate_job_with_delay(self, mock_sleep): - with patch.object(self.hook.sdk.job, 'terminate', return_value=None) as mock_terminate: + with patch.object(self.hook.sdk.job, "terminate", return_value=None) as mock_terminate: result = self.hook.terminate_job("test_job_id", time_delay=1) mock_terminate.assert_called_once_with(name="test_job_id") mock_sleep.assert_called_once_with(1) @@ -269,8 +271,8 @@ def test_terminate_job_with_delay(self, mock_sleep): @patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None) def test_terminate_service_with_delay(self, mock_sleep): - with patch.object(self.hook.sdk.service, 'terminate', return_value=None) as mock_terminate: + with patch.object(self.hook.sdk.service, "terminate", return_value=None) as mock_terminate: result = self.hook.terminate_service("test_service_id", time_delay=1) mock_terminate.assert_called_once_with(name="test_service_id") mock_sleep.assert_called_once_with(1) - assert result is True \ No newline at end of file + assert result is True diff --git a/tests/operators/test_anyscale_operators.py b/tests/operators/test_anyscale_operators.py index ef66c6e..2bed794 100644 --- a/tests/operators/test_anyscale_operators.py +++ b/tests/operators/test_anyscale_operators.py @@ -1,140 +1,141 @@ import unittest -from unittest.mock import patch, MagicMock,PropertyMock -from airflow.utils.context import Context +from unittest.mock import MagicMock, PropertyMock, patch + from airflow.exceptions import AirflowException, TaskDeferred +from airflow.utils.context import Context from anyscale.job.models import JobState from anyscale.service.models import ServiceState -from anyscale_provider.operators.anyscale import SubmitAnyscaleJob -from anyscale_provider.operators.anyscale import RolloutAnyscaleService -from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger,AnyscaleServiceTrigger + +from anyscale_provider.operators.anyscale import RolloutAnyscaleService, SubmitAnyscaleJob +from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger class TestSubmitAnyscaleJob(unittest.TestCase): def setUp(self): self.operator = SubmitAnyscaleJob( - conn_id='test_conn', - name='test_job', - image_uri='test_image_uri', - compute_config={}, - working_dir='/test/dir', - entrypoint='test_entrypoint', - task_id='submit_job_test' + conn_id="test_conn", + name="test_job", + image_uri="test_image_uri", + compute_config={}, + working_dir="/test/dir", + entrypoint="test_entrypoint", + task_id="submit_job_test", ) - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status') - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=MagicMock) + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status") + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook", new_callable=MagicMock) def test_execute_successful(self, mock_hook, mock_get_status): job_result_mock = MagicMock() - job_result_mock.id = '123' - mock_hook.submit_job.return_value = '123' + job_result_mock.id = "123" + mock_hook.submit_job.return_value = "123" mock_get_status.return_value = JobState.SUCCEEDED - + job_id = self.operator.execute(Context()) - - mock_hook.submit_job.assert_called_once() - mock_get_status.assert_called_with('123') - self.assertEqual(job_id, '123') - - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.process_job_status') - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status') - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook') + + mock_hook.submit_job.assert_called_once() + mock_get_status.assert_called_with("123") + self.assertEqual(job_id, "123") + + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.process_job_status") + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status") + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_execute_fail_on_status(self, mock_hook, mock_get_current_status, mock_process_job_status): - mock_hook.submit_job.return_value = '123' + mock_hook.submit_job.return_value = "123" mock_get_current_status.return_value = JobState.FAILED mock_process_job_status.side_effect = AirflowException("Job 123 failed.") - + with self.assertRaises(AirflowException) as context: self.operator.execute(Context()) - + self.assertTrue("Job 123 failed." in str(context.exception)) - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook') + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_on_kill(self, mock_hook): - self.operator.job_id = '123' - self.operator.on_kill() - mock_hook.terminate_job.assert_called_once_with('123', 5) + self.operator.job_id = "123" + self.operator.on_kill() + mock_hook.terminate_job.assert_called_once_with("123", 5) - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook') + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_process_job_status_unexpected_state(self, mock_hook): with self.assertRaises(Exception): - self.operator.process_job_status(None, 'UNKNOWN_STATE') + self.operator.process_job_status(None, "UNKNOWN_STATE") - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.defer_job_polling') - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook') + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.defer_job_polling") + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_defer_job_polling_called(self, mock_hook, mock_defer_job_polling): mock_hook.get_job_status.return_value = JobState.STARTING - self.operator.process_job_status('123', JobState.STARTING) + self.operator.process_job_status("123", JobState.STARTING) mock_defer_job_polling.assert_called_once() - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook') + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_execute_complete(self, mock_hook): - event = {'status': JobState.SUCCEEDED, 'job_id': '123', 'message': 'Job completed successfully'} + event = {"status": JobState.SUCCEEDED, "job_id": "123", "message": "Job completed successfully"} self.assertEqual(self.operator.execute_complete(Context(), event), None) - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook') + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_execute_complete_failure(self, mock_hook): - event = {'status': JobState.FAILED, 'job_id': '123', 'message': 'Job failed with error'} + event = {"status": JobState.FAILED, "job_id": "123", "message": "Job failed with error"} with self.assertRaises(AirflowException) as context: self.operator.execute_complete(Context(), event) self.assertTrue("Job 123 failed with error" in str(context.exception)) - + def test_no_job_name(self): with self.assertRaises(AirflowException) as context: SubmitAnyscaleJob( - conn_id='test_conn', - name='', # No job name - image_uri='test_image_uri', - compute_config={}, - working_dir='/test/dir', - entrypoint='test_entrypoint', - task_id='submit_job_test' + conn_id="test_conn", + name="", # No job name + image_uri="test_image_uri", + compute_config={}, + working_dir="/test/dir", + entrypoint="test_entrypoint", + task_id="submit_job_test", ) self.assertTrue("Job name is required." in str(context.exception)) def test_no_entrypoint_provided(self): with self.assertRaises(AirflowException) as context: SubmitAnyscaleJob( - conn_id='test_conn', - name='test_job', - image_uri='test_image_uri', - compute_config={}, - working_dir='/test/dir', - entrypoint='', # No entrypoint - task_id='submit_job_test' + conn_id="test_conn", + name="test_job", + image_uri="test_image_uri", + compute_config={}, + working_dir="/test/dir", + entrypoint="", # No entrypoint + task_id="submit_job_test", ) self.assertTrue("Entrypoint must be specified." in str(context.exception)) - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock) + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook", new_callable=PropertyMock) def test_check_anyscale_hook(self, mock_hook_property): # Access the hook property - hook = self.operator.hook + self.operator.hook # Verify that the hook property was accessed mock_hook_property.assert_called_once() - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock) + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook", new_callable=PropertyMock) def test_execute_with_no_hook(self, mock_hook_property): # Simulate the hook not being available by raising an AirflowException mock_hook_property.side_effect = AirflowException("SDK is not available.") - + # Execute the operator and expect it to raise an AirflowException with self.assertRaises(AirflowException) as context: self.operator.execute(Context()) self.assertTrue("SDK is not available." in str(context.exception)) - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status') - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=MagicMock) + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status") + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook", new_callable=MagicMock) def test_job_state_failed(self, mock_hook, mock_get_status): job_result_mock = MagicMock() - job_result_mock.id = '123' - mock_hook.submit_job.return_value = '123' + job_result_mock.id = "123" + mock_hook.submit_job.return_value = "123" mock_get_status.return_value = JobState.FAILED - + with self.assertRaises(AirflowException) as context: self.operator.execute(Context()) self.assertTrue("Job 123 failed." in str(context.exception)) - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock) + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook", new_callable=PropertyMock) def test_get_current_status(self, mock_hook_property): mock_hook = MagicMock() mock_job_status = MagicMock(state=JobState.SUCCEEDED) @@ -142,19 +143,19 @@ def test_get_current_status(self, mock_hook_property): mock_hook_property.return_value = mock_hook # Call the method to test - status = self.operator.get_current_status('123') + status = self.operator.get_current_status("123") # Verify the result - self.assertEqual(status, 'SUCCEEDED') + self.assertEqual(status, "SUCCEEDED") # Ensure the mock was called correctly - mock_hook.get_job_status.assert_called_once_with(job_id='123') - - @patch('airflow.models.BaseOperator.defer') - @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=MagicMock) + mock_hook.get_job_status.assert_called_once_with(job_id="123") + + @patch("airflow.models.BaseOperator.defer") + @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook", new_callable=MagicMock) def test_defer_job_polling(self, mock_hook, mock_defer): # Mock the submit_job method to return a job ID - mock_hook.submit_job.return_value = '123' + mock_hook.submit_job.return_value = "123" # Mock the get_job_status method to return a starting state mock_hook.get_job_status.return_value.state = JobState.STARTING @@ -164,31 +165,31 @@ def test_defer_job_polling(self, mock_hook, mock_defer): # Check that the defer method was called with the correct arguments mock_defer.assert_called_once() args, kwargs = mock_defer.call_args - self.assertIsInstance(kwargs['trigger'], AnyscaleJobTrigger) - self.assertEqual(kwargs['trigger'].job_id, '123') - self.assertEqual(kwargs['trigger'].conn_id, 'test_conn') - self.assertEqual(kwargs['method_name'], 'execute_complete') - + self.assertIsInstance(kwargs["trigger"], AnyscaleJobTrigger) + self.assertEqual(kwargs["trigger"].job_id, "123") + self.assertEqual(kwargs["trigger"].conn_id, "test_conn") + self.assertEqual(kwargs["method_name"], "execute_complete") + class TestRolloutAnyscaleService(unittest.TestCase): def setUp(self): self.operator = RolloutAnyscaleService( - conn_id='test_conn', - name='test_service', - image_uri='test_image_uri', - working_dir='/test/dir', - applications=[{'name': 'app1', 'import_path': 'module.optional_submodule:app'}], - compute_config='config123', - task_id='rollout_service_test' + conn_id="test_conn", + name="test_service", + image_uri="test_image_uri", + working_dir="/test/dir", + applications=[{"name": "app1", "import_path": "module.optional_submodule:app"}], + compute_config="config123", + task_id="rollout_service_test", ) - @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook') + @patch("anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook") def test_execute_successful(self, mock_hook): - mock_hook.return_value.deploy_service.return_value = 'service123' + mock_hook.return_value.deploy_service.return_value = "service123" with self.assertRaises(TaskDeferred): self.operator.execute(Context()) - @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook', new_callable=MagicMock) + @patch("anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook", new_callable=MagicMock) def test_execute_fail_sdk_unavailable(self, mock_hook): self.operator.hook = None @@ -197,32 +198,32 @@ def test_execute_fail_sdk_unavailable(self, mock_hook): self.assertEqual(str(cm.exception), "SDK is not available") - @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.defer') - @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook', new_callable=MagicMock) + @patch("anyscale_provider.operators.anyscale.RolloutAnyscaleService.defer") + @patch("anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook", new_callable=MagicMock) def test_defer_trigger_called(self, mock_hook, mock_defer): - mock_hook.return_value.deploy_service.return_value = 'service123' - + mock_hook.return_value.deploy_service.return_value = "service123" + self.operator.execute(Context()) - + # Extract the actual call arguments actual_call_args = mock_defer.call_args - + # Define the expected trigger and method_name expected_trigger = AnyscaleServiceTrigger( - conn_id='test_conn', - service_name='test_service', + conn_id="test_conn", + service_name="test_service", expected_state=ServiceState.RUNNING, canary_percent=None, poll_interval=60, - timeout=600 + timeout=600, ) - + expected_method_name = "execute_complete" - + # Perform individual assertions - actual_trigger = actual_call_args.kwargs['trigger'] - actual_method_name = actual_call_args.kwargs['method_name'] - + actual_trigger = actual_call_args.kwargs["trigger"] + actual_method_name = actual_call_args.kwargs["method_name"] + self.assertEqual(actual_trigger.conn_id, expected_trigger.conn_id) self.assertEqual(actual_trigger.service_name, expected_trigger.service_name) self.assertEqual(actual_trigger.expected_state, expected_trigger.expected_state) @@ -230,49 +231,50 @@ def test_defer_trigger_called(self, mock_hook, mock_defer): self.assertEqual(actual_trigger.timeout, expected_trigger.timeout) self.assertEqual(actual_method_name, expected_method_name) - @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook') + @patch("anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook") def test_execute_complete_failed(self, mock_hook): - event = {'status': ServiceState.SYSTEM_FAILURE, 'service_name': 'service123', 'message': 'Deployment failed'} + event = {"status": ServiceState.SYSTEM_FAILURE, "service_name": "service123", "message": "Deployment failed"} with self.assertRaises(AirflowException) as cm: self.operator.execute_complete(Context(), event) self.assertIn("Job service123 failed with error Deployment failed", str(cm.exception)) - @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook') + @patch("anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook") def test_execute_complete_success(self, mock_hook): - event = {'status': ServiceState.RUNNING, 'service_name': 'service123', 'message': 'Deployment succeeded'} + event = {"status": ServiceState.RUNNING, "service_name": "service123", "message": "Deployment succeeded"} self.operator.execute_complete(Context(), event) - self.assertEqual(self.operator.service_params['name'], 'test_service') + self.assertEqual(self.operator.service_params["name"], "test_service") - @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook', new_callable=PropertyMock) + @patch("anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook", new_callable=PropertyMock) def test_check_anyscale_hook(self, mock_hook_property): - hook = self.operator.hook + self.operator.hook mock_hook_property.assert_called_once() - + def test_no_service_name(self): with self.assertRaises(ValueError) as cm: RolloutAnyscaleService( - conn_id='test_conn', - name='', # No service name - image_uri='test_image_uri', - working_dir='/test/dir', - applications=[{'name': 'app1', 'import_path': 'module.optional_submodule:app'}], - compute_config='config123', - task_id='rollout_service_test' + conn_id="test_conn", + name="", # No service name + image_uri="test_image_uri", + working_dir="/test/dir", + applications=[{"name": "app1", "import_path": "module.optional_submodule:app"}], + compute_config="config123", + task_id="rollout_service_test", ) self.assertIn("Service name is required", str(cm.exception)) def test_no_applications(self): with self.assertRaises(ValueError) as cm: RolloutAnyscaleService( - conn_id='test_conn', - name='test_service', - image_uri='test_image_uri', - working_dir='/test/dir', + conn_id="test_conn", + name="test_service", + image_uri="test_image_uri", + working_dir="/test/dir", applications=[], # No applications - compute_config='config123', - task_id='rollout_service_test' + compute_config="config123", + task_id="rollout_service_test", ) self.assertIn("At least one application must be specified", str(cm.exception)) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/triggers/test_anyscale_triggers.py b/tests/triggers/test_anyscale_triggers.py index cd3deec..9eee662 100644 --- a/tests/triggers/test_anyscale_triggers.py +++ b/tests/triggers/test_anyscale_triggers.py @@ -1,149 +1,140 @@ -import unittest -from unittest.mock import patch, MagicMock, PropertyMock import asyncio +import unittest from datetime import datetime -import os -import time -import pytest -from typing import Any, Dict, AsyncIterator, Tuple, Optional -from pathlib import Path -from airflow.exceptions import AirflowNotFoundException +from unittest.mock import MagicMock, PropertyMock, patch -from anyscale.job.models import JobState, JobStatus, JobConfig, JobRunStatus -from anyscale.service.models import ServiceState, ServiceStatus +from airflow.exceptions import AirflowNotFoundException +from anyscale.job.models import JobState +from anyscale.service.models import ServiceState -from anyscale_provider.hooks.anyscale import AnyscaleHook from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger -from airflow.triggers.base import TriggerEvent -from airflow.models.connection import Connection + class TestAnyscaleJobTrigger(unittest.TestCase): def setUp(self): - self.trigger = AnyscaleJobTrigger(conn_id='anyscale_default', - job_id='123', - job_start_time=datetime.now().timestamp()) + self.trigger = AnyscaleJobTrigger( + conn_id="anyscale_default", job_id="123", job_start_time=datetime.now().timestamp() + ) - @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status') + @patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status") def test_is_terminal_status(self, mock_get_status): - mock_get_status.return_value = 'SUCCEEDED' - self.assertTrue(self.trigger.is_terminal_status('123')) + mock_get_status.return_value = "SUCCEEDED" + self.assertTrue(self.trigger.is_terminal_status("123")) - @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status') + @patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status") def test_is_not_terminal_status(self, mock_get_status): - mock_get_status.return_value = 'RUNNING' - self.assertFalse(self.trigger.is_terminal_status('123')) + mock_get_status.return_value = "RUNNING" + self.assertFalse(self.trigger.is_terminal_status("123")) - @patch('asyncio.sleep', return_value=None) - @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status', side_effect=['RUNNING', 'RUNNING', 'SUCCEEDED']) + @patch("asyncio.sleep", return_value=None) + @patch( + "anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status", + side_effect=["RUNNING", "RUNNING", "SUCCEEDED"], + ) async def test_run_successful_completion(self, mock_get_status, mock_sleep): events = [] async for event in self.trigger.run(): events.append(event) self.assertEqual(len(events), 1) - self.assertEqual(events[0].payload['status'], 'SUCCEEDED') + self.assertEqual(events[0].payload["status"], "SUCCEEDED") - @patch('time.time', side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout - @patch('asyncio.sleep', return_value=None) + @patch("time.time", side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout + @patch("asyncio.sleep", return_value=None) async def test_run_timeout(self, mock_sleep, mock_time): events = [] async for event in self.trigger.run(): events.append(event) self.assertEqual(len(events), 1) - self.assertEqual(events[0].payload['status'], 'timeout') + self.assertEqual(events[0].payload["status"], "timeout") - @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status', side_effect=Exception("Error occurred")) + @patch( + "anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status", + side_effect=Exception("Error occurred"), + ) async def test_run_exception(self, mock_is_terminal_status): events = [] async for event in self.trigger.run(): events.append(event) self.assertEqual(len(events), 1) - self.assertEqual(events[0].payload['status'], JobState.FAILED) - self.assertIn('Error occurred', events[0].payload['message']) + self.assertEqual(events[0].payload["status"], JobState.FAILED) + self.assertIn("Error occurred", events[0].payload["message"]) - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status') + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status") def test_get_current_status(self, mock_get_job_status): mock_get_job_status.return_value = MagicMock(state=JobState.SUCCEEDED) - trigger = AnyscaleJobTrigger(conn_id='default_conn', - job_id='123', - job_start_time=datetime.now().timestamp()) + trigger = AnyscaleJobTrigger(conn_id="default_conn", job_id="123", job_start_time=datetime.now().timestamp()) # Mock the hook property to return our mocked hook - with patch.object(AnyscaleJobTrigger, 'hook', new_callable=PropertyMock) as mock_hook: + with patch.object(AnyscaleJobTrigger, "hook", new_callable=PropertyMock) as mock_hook: mock_hook.return_value.get_job_status = mock_get_job_status - + # Call the method to test - status = trigger.get_current_status('123') + status = trigger.get_current_status("123") # Verify the result - self.assertEqual(status, 'SUCCEEDED') - + self.assertEqual(status, "SUCCEEDED") + # Ensure the mock was called correctly - mock_get_job_status.assert_called_once_with(job_id='123') + mock_get_job_status.assert_called_once_with(job_id="123") - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs') - @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status', side_effect=['RUNNING', 'SUCCEEDED']) - @patch('asyncio.sleep', return_value=None) + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs") + @patch( + "anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status", + side_effect=["RUNNING", "SUCCEEDED"], + ) + @patch("asyncio.sleep", return_value=None) async def test_run_with_logs(self, mock_sleep, mock_get_status, mock_get_logs): mock_get_logs.return_value = "log line 1\nlog line 2" events = [] async for event in self.trigger.run(): events.append(event) - + self.assertEqual(len(events), 1) - self.assertEqual(events[0].payload['status'], 'SUCCEEDED') + self.assertEqual(events[0].payload["status"], "SUCCEEDED") async def test_run_no_job_id_provided(self): - trigger = AnyscaleJobTrigger(conn_id='default_conn', - job_id='', - job_start_time=datetime.now().timestamp()) + trigger = AnyscaleJobTrigger(conn_id="default_conn", job_id="", job_start_time=datetime.now().timestamp()) events = [] async for event in trigger.run(): events.append(event) self.assertEqual(len(events), 1) - self.assertEqual(events[0].payload['status'], 'error') - self.assertIn("No job_id provided to async trigger", events[0].payload['message']) - - @patch('airflow.models.connection.Connection.get_connection_from_secrets') + self.assertEqual(events[0].payload["status"], "error") + self.assertIn("No job_id provided to async trigger", events[0].payload["message"]) + + @patch("airflow.models.connection.Connection.get_connection_from_secrets") def test_hook_method(self, mock_get_connection): # Configure the mock to raise AirflowNotFoundException mock_get_connection.side_effect = AirflowNotFoundException("The conn_id `default_conn` isn't defined") - trigger = AnyscaleJobTrigger(conn_id='default_conn', - job_id='123', - job_start_time=datetime.now().timestamp()) - + trigger = AnyscaleJobTrigger(conn_id="default_conn", job_id="123", job_start_time=datetime.now().timestamp()) + with self.assertRaises(AirflowNotFoundException) as context: - result = trigger.hook() + trigger.hook() self.assertIn("The conn_id `default_conn` isn't defined", str(context.exception)) - + def test_serialize(self): time = datetime.now().timestamp() - trigger = AnyscaleJobTrigger(conn_id='default_conn', - job_id='123', - job_start_time=time) - + trigger = AnyscaleJobTrigger(conn_id="default_conn", job_id="123", job_start_time=time) + result = trigger.serialize() - expected_output = ("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger", { - "conn_id": 'default_conn', - "job_id": '123', - "job_start_time": time, - "poll_interval": 60, - "timeout": 3600 - }) + expected_output = ( + "anyscale_provider.triggers.anyscale.AnyscaleJobTrigger", + {"conn_id": "default_conn", "job_id": "123", "job_start_time": time, "poll_interval": 60, "timeout": 3600}, + ) # Check if the result is a tuple self.assertTrue(isinstance(result, tuple)) - + # Check if the tuple contains a string and a dictionary self.assertTrue(isinstance(result[0], str)) self.assertTrue(isinstance(result[1], dict)) - + # Check if the result matches the expected output self.assertEqual(result, expected_output) @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status") @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs") - @patch('asyncio.sleep', return_value=None) + @patch("asyncio.sleep", return_value=None) async def test_anyscale_run_trigger(self, mocked_sleep, mocked_get_logs, mocked_get_job_status): """Test AnyscaleJobTrigger run method with mocked details.""" mocked_get_job_status.return_value.state = JobState.SUCCEEDED @@ -168,146 +159,152 @@ async def test_anyscale_run_trigger(self, mocked_sleep, mocked_get_logs, mocked_ self.assertEqual(result.payload["status"], JobState.SUCCEEDED) self.assertEqual(result.payload["message"], "Job 1234 completed with status JobState.SUCCEEDED.") self.assertEqual(result.payload["job_id"], "1234") - + class TestAnyscaleServiceTrigger(unittest.TestCase): def setUp(self): - self.trigger = AnyscaleServiceTrigger(conn_id='default_conn', - service_name='service123', - expected_state='RUNNING', - canary_percent=None) + self.trigger = AnyscaleServiceTrigger( + conn_id="default_conn", service_name="service123", expected_state="RUNNING", canary_percent=None + ) - @patch('anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.get_current_status') + @patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.get_current_status") def test_check_current_status(self, mock_get_status): mock_get_status.return_value = "STARTING" - self.assertTrue(self.trigger.check_current_status('service123')) + self.assertTrue(self.trigger.check_current_status("service123")) - @patch('asyncio.sleep', return_value=None) - @patch('anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.get_current_status', side_effect=['STARTING', 'UPDATING', 'RUNNING']) + @patch("asyncio.sleep", return_value=None) + @patch( + "anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.get_current_status", + side_effect=["STARTING", "UPDATING", "RUNNING"], + ) async def test_run_successful(self, mock_get_status, mock_sleep): events = [] async for event in self.trigger.run(): events.append(event) self.assertEqual(len(events), 1) - self.assertEqual(events[0]['status'], ServiceState.RUNNING) - self.assertIn('Service deployment succeeded', events[0]['message']) + self.assertEqual(events[0]["status"], ServiceState.RUNNING) + self.assertIn("Service deployment succeeded", events[0]["message"]) - @patch('time.time', side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout - @patch('asyncio.sleep', return_value=None) + @patch("time.time", side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout + @patch("asyncio.sleep", return_value=None) async def test_run_timeout(self, mock_sleep, mock_time): events = [] async for event in self.trigger.run(): events.append(event) self.assertEqual(len(events), 1) - self.assertEqual(events[0]['status'], ServiceState.UNKNOWN) - self.assertIn('did not reach RUNNING within the timeout period', events[0]['message']) + self.assertEqual(events[0]["status"], ServiceState.UNKNOWN) + self.assertIn("did not reach RUNNING within the timeout period", events[0]["message"]) - @patch('anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.check_current_status', side_effect=Exception("Error occurred")) + @patch( + "anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.check_current_status", + side_effect=Exception("Error occurred"), + ) async def test_run_exception(self, mock_check_current_status): events = [] async for event in self.trigger.run(): events.append(event) self.assertEqual(len(events), 1) - self.assertEqual(events[0]['status'], ServiceState.SYSTEM_FAILURE) - self.assertIn('Error occurred', events[0]['message']) - - @patch('airflow.models.connection.Connection.get_connection_from_secrets') + self.assertEqual(events[0]["status"], ServiceState.SYSTEM_FAILURE) + self.assertIn("Error occurred", events[0]["message"]) + + @patch("airflow.models.connection.Connection.get_connection_from_secrets") def test_hook_method(self, mock_get_connection): # Configure the mock to raise AirflowNotFoundException mock_get_connection.side_effect = AirflowNotFoundException("The conn_id `default_conn` isn't defined") - trigger = AnyscaleServiceTrigger(conn_id='default_conn', - service_name="AstroService", - expected_state=ServiceState.RUNNING, - canary_percent=0.0) - + trigger = AnyscaleServiceTrigger( + conn_id="default_conn", service_name="AstroService", expected_state=ServiceState.RUNNING, canary_percent=0.0 + ) + with self.assertRaises(AirflowNotFoundException) as context: - result = trigger.hook() + trigger.hook() self.assertIn("The conn_id `default_conn` isn't defined", str(context.exception)) - + def test_serialize(self): - - trigger = AnyscaleServiceTrigger(conn_id='default_conn', - service_name="AstroService", - expected_state=ServiceState.RUNNING, - canary_percent=0.0) - + + trigger = AnyscaleServiceTrigger( + conn_id="default_conn", service_name="AstroService", expected_state=ServiceState.RUNNING, canary_percent=0.0 + ) + result = trigger.serialize() - expected_output = ("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger", { - "conn_id": 'default_conn', - "service_name": "AstroService", - "expected_state": ServiceState.RUNNING, - "canary_percent": 0.0, - "poll_interval": 60, - "timeout": 600 - }) + expected_output = ( + "anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger", + { + "conn_id": "default_conn", + "service_name": "AstroService", + "expected_state": ServiceState.RUNNING, + "canary_percent": 0.0, + "poll_interval": 60, + "timeout": 600, + }, + ) # Check if the result is a tuple self.assertTrue(isinstance(result, tuple)) - + # Check if the tuple contains a string and a dictionary self.assertTrue(isinstance(result[0], str)) self.assertTrue(isinstance(result[1], dict)) - + # Check if the result matches the expected output self.assertEqual(result, expected_output) - - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status') + + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status") def test_get_current_status_canary_0_percent(self, mock_get_service_status): # Mock the return value of get_service_status mock_service_status = MagicMock() mock_service_status.state = ServiceState.RUNNING mock_service_status.canary_version.state = ServiceState.RUNNING mock_get_service_status.return_value = mock_service_status - + # Initialize the trigger with canary_percent set to 0.0 - trigger = AnyscaleServiceTrigger(conn_id='default_conn', - service_name="AstroService", - expected_state=ServiceState.RUNNING, - canary_percent=0.0) - + trigger = AnyscaleServiceTrigger( + conn_id="default_conn", service_name="AstroService", expected_state=ServiceState.RUNNING, canary_percent=0.0 + ) + # Mock the hook property to return our mocked hook - with patch.object(AnyscaleServiceTrigger, 'hook', new_callable=PropertyMock) as mock_hook: + with patch.object(AnyscaleServiceTrigger, "hook", new_callable=PropertyMock) as mock_hook: mock_hook.return_value.get_service_status = mock_get_service_status - + # Call the method to test - status = trigger.get_current_status('AstroService') - + status = trigger.get_current_status("AstroService") + # Verify the result - self.assertEqual(status, 'RUNNING') - + self.assertEqual(status, "RUNNING") + # Ensure the mock was called correctly - mock_get_service_status.assert_called_once_with('AstroService') - - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status') + mock_get_service_status.assert_called_once_with("AstroService") + + @patch("anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status") def test_get_current_status_canary_100_percent(self, mock_get_service_status): # Mock the return value of get_service_status mock_service_status = MagicMock() mock_service_status.state = ServiceState.TERMINATED mock_service_status.canary_version.state = ServiceState.RUNNING mock_get_service_status.return_value = mock_service_status - + # Initialize the trigger with canary_percent set to 100.0 - trigger = AnyscaleServiceTrigger(conn_id='default_conn', - service_name="AstroService", - expected_state=ServiceState.RUNNING, - canary_percent=100.0) - + trigger = AnyscaleServiceTrigger( + conn_id="default_conn", + service_name="AstroService", + expected_state=ServiceState.RUNNING, + canary_percent=100.0, + ) + # Mock the hook property to return our mocked hook - with patch.object(AnyscaleServiceTrigger, 'hook', new_callable=PropertyMock) as mock_hook: + with patch.object(AnyscaleServiceTrigger, "hook", new_callable=PropertyMock) as mock_hook: mock_hook.return_value.get_service_status = mock_get_service_status - + # Call the method to test - status = trigger.get_current_status('AstroService') - + status = trigger.get_current_status("AstroService") + # Verify the result - self.assertEqual(status, 'TERMINATED') - + self.assertEqual(status, "TERMINATED") + # Ensure the mock was called correctly - mock_get_service_status.assert_called_once_with('AstroService') - + mock_get_service_status.assert_called_once_with("AstroService") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main()