diff --git a/tests/triggers/test_anyscale_triggers.py b/tests/triggers/test_anyscale_triggers.py index 59252f6..f91a90b 100644 --- a/tests/triggers/test_anyscale_triggers.py +++ b/tests/triggers/test_anyscale_triggers.py @@ -4,14 +4,14 @@ from datetime import datetime import os import pytest +from typing import Any, Dict, AsyncIterator, Tuple, Optional from pathlib import Path -from airflow.models import DagBag, Connection -from airflow.utils.db import create_default_connections -from airflow.utils.session import provide_session, create_session +from airflow.exceptions import AirflowNotFoundException from anyscale.job.models import JobState from anyscale.service.models import ServiceState +from anyscale_provider.hooks.anyscale import AnyscaleHook from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger from airflow.triggers.base import TriggerEvent from airflow.models.connection import Connection @@ -60,8 +60,15 @@ async def test_run_exception(self, mock_is_terminal_status): self.assertIn('Error occurred', events[0]['message']) @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status') - def test_get_current_status(self, mock_get_job_status): + async def test_get_current_status(self, mock_get_job_status): mock_get_job_status.return_value = MagicMock(state=JobState.SUCCEEDED) + trigger = AnyscaleJobTrigger(conn_id='default_conn', + job_id='', + job_start_time=datetime.now().timestamp()) + events = [] + async for event in trigger.run(): + events.append(event) + status = self.trigger.get_current_status('123') self.assertEqual(status, JobState.SUCCEEDED) mock_get_job_status.assert_called_once_with(job_id='123') @@ -76,6 +83,45 @@ async def test_run_no_job_id_provided(self): self.assertEqual(len(events), 1) self.assertEqual(events[0]['status'], 'error') self.assertIn('No job_id provided', events[0]['message']) + + @patch('airflow.models.connection.Connection.get_connection_from_secrets') + def test_hook_method(self, mock_get_connection): + # Configure the mock to raise AirflowNotFoundException + mock_get_connection.side_effect = AirflowNotFoundException("The conn_id `default_conn` isn't defined") + + trigger = AnyscaleJobTrigger(conn_id='default_conn', + job_id='123', + job_start_time=datetime.now().timestamp()) + + with self.assertRaises(AirflowNotFoundException) as context: + result = trigger.hook() + + self.assertIn("The conn_id `default_conn` isn't defined", str(context.exception)) + + def test_serialize(self): + time = datetime.now().timestamp() + trigger = AnyscaleJobTrigger(conn_id='default_conn', + job_id='123', + job_start_time=time) + + result = trigger.serialize() + expected_output = ("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger", { + "conn_id": 'default_conn', + "job_id": '123', + "job_start_time": time, + "poll_interval": 60, + "timeout": 3600 + }) + + # Check if the result is a tuple + self.assertTrue(isinstance(result, tuple)) + + # Check if the tuple contains a string and a dictionary + self.assertTrue(isinstance(result[0], str)) + self.assertTrue(isinstance(result[1], dict)) + + # Check if the result matches the expected output + self.assertEqual(result, expected_output) class TestAnyscaleServiceTrigger(unittest.TestCase): def setUp(self):