diff --git a/tests/triggers/test_anyscale_triggers.py b/tests/triggers/test_anyscale_triggers.py index b4f77d0..5b0fb2d 100644 --- a/tests/triggers/test_anyscale_triggers.py +++ b/tests/triggers/test_anyscale_triggers.py @@ -1,39 +1,55 @@ import unittest from unittest.mock import patch, MagicMock +import asyncio from datetime import datetime from anyscale.job.models import JobState +from anyscale.service.models import ServiceState -from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger,AnyscaleServiceTrigger +from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger class TestAnyscaleJobTrigger(unittest.TestCase): def setUp(self): self.trigger = AnyscaleJobTrigger(conn_id='default_conn', job_id='123', - job_start_time=datetime.now()) + job_start_time=datetime.now().timestamp()) @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status') def test_is_terminal_status(self, mock_get_status): mock_get_status.return_value = 'COMPLETED' self.assertTrue(self.trigger.is_terminal_status('123')) - @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status') + @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status') def test_is_not_terminal_status(self, mock_get_status): - mock_get_status.return_value = False + mock_get_status.return_value = 'RUNNING' self.assertFalse(self.trigger.is_terminal_status('123')) @patch('asyncio.sleep', return_value=None) @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status', side_effect=['RUNNING', 'RUNNING', 'COMPLETED']) async def test_run_successful_completion(self, mock_get_status, mock_sleep): + events = [] async for event in self.trigger.run(): - self.assertIn('status', event) - self.assertEqual(event['status'], 'COMPLETED') + events.append(event) + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['status'], 'COMPLETED') @patch('time.time', side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout @patch('asyncio.sleep', return_value=None) async def test_run_timeout(self, mock_sleep, mock_time): + events = [] + async for event in self.trigger.run(): + events.append(event) + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['status'], 'timeout') + + @patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status', side_effect=Exception("Error occurred")) + async def test_run_exception(self, mock_is_terminal_status): + events = [] async for event in self.trigger.run(): - self.assertEqual(event['status'], 'timeout') + events.append(event) + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['status'], JobState.FAILED) + self.assertIn('Error occurred', events[0]['message']) class TestAnyscaleServiceTrigger(unittest.TestCase): def setUp(self): @@ -50,15 +66,31 @@ def test_check_current_status(self, mock_get_status): @patch('asyncio.sleep', return_value=None) @patch('anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.get_current_status', side_effect=['STARTING', 'UPDATING', 'RUNNING']) async def test_run_successful(self, mock_get_status, mock_sleep): + events = [] async for event in self.trigger.run(): - self.assertEqual(event['status'], 'success') - self.assertIn('Service deployment succeeded', event['message']) + events.append(event) + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['status'], ServiceState.RUNNING) + self.assertIn('Service deployment succeeded', events[0]['message']) @patch('time.time', side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout @patch('asyncio.sleep', return_value=None) async def test_run_timeout(self, mock_sleep, mock_time): + events = [] + async for event in self.trigger.run(): + events.append(event) + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['status'], ServiceState.UNKNOWN) + self.assertIn('did not reach RUNNING within the timeout period', events[0]['message']) + + @patch('anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.check_current_status', side_effect=Exception("Error occurred")) + async def test_run_exception(self, mock_check_current_status): + events = [] async for event in self.trigger.run(): - self.assertEqual(event['status'], 'timeout') + events.append(event) + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['status'], ServiceState.SYSTEM_FAILURE) + self.assertIn('Error occurred', events[0]['message']) if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file