From 5d1c284cd83e11aa72b1fef81a49cbef3d0e12d3 Mon Sep 17 00:00:00 2001 From: Venkat Date: Wed, 12 Jun 2024 23:46:08 -0400 Subject: [PATCH] Added additional unit tests --- tests/operators/test_anyscale_operators.py | 124 ++++++++++++++++++++- tests/triggers/test_anyscale_triggers.py | 36 ++++-- 2 files changed, 151 insertions(+), 9 deletions(-) diff --git a/tests/operators/test_anyscale_operators.py b/tests/operators/test_anyscale_operators.py index 4a40e1e..ef66c6e 100644 --- a/tests/operators/test_anyscale_operators.py +++ b/tests/operators/test_anyscale_operators.py @@ -1,12 +1,12 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock,PropertyMock from airflow.utils.context import Context from airflow.exceptions import AirflowException, TaskDeferred from anyscale.job.models import JobState from anyscale.service.models import ServiceState from anyscale_provider.operators.anyscale import SubmitAnyscaleJob from anyscale_provider.operators.anyscale import RolloutAnyscaleService -from anyscale_provider.triggers.anyscale import AnyscaleServiceTrigger +from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger,AnyscaleServiceTrigger class TestSubmitAnyscaleJob(unittest.TestCase): @@ -77,6 +77,97 @@ def test_execute_complete_failure(self, mock_hook): with self.assertRaises(AirflowException) as context: self.operator.execute_complete(Context(), event) self.assertTrue("Job 123 failed with error" in str(context.exception)) + + def test_no_job_name(self): + with self.assertRaises(AirflowException) as context: + SubmitAnyscaleJob( + conn_id='test_conn', + name='', # No job name + image_uri='test_image_uri', + compute_config={}, + working_dir='/test/dir', + entrypoint='test_entrypoint', + task_id='submit_job_test' + ) + self.assertTrue("Job name is required." in str(context.exception)) + + def test_no_entrypoint_provided(self): + with self.assertRaises(AirflowException) as context: + SubmitAnyscaleJob( + conn_id='test_conn', + name='test_job', + image_uri='test_image_uri', + compute_config={}, + working_dir='/test/dir', + entrypoint='', # No entrypoint + task_id='submit_job_test' + ) + self.assertTrue("Entrypoint must be specified." in str(context.exception)) + + @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock) + def test_check_anyscale_hook(self, mock_hook_property): + # Access the hook property + hook = self.operator.hook + # Verify that the hook property was accessed + mock_hook_property.assert_called_once() + + @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock) + def test_execute_with_no_hook(self, mock_hook_property): + # Simulate the hook not being available by raising an AirflowException + mock_hook_property.side_effect = AirflowException("SDK is not available.") + + # Execute the operator and expect it to raise an AirflowException + with self.assertRaises(AirflowException) as context: + self.operator.execute(Context()) + + self.assertTrue("SDK is not available." in str(context.exception)) + + @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status') + @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=MagicMock) + def test_job_state_failed(self, mock_hook, mock_get_status): + job_result_mock = MagicMock() + job_result_mock.id = '123' + mock_hook.submit_job.return_value = '123' + mock_get_status.return_value = JobState.FAILED + + with self.assertRaises(AirflowException) as context: + self.operator.execute(Context()) + self.assertTrue("Job 123 failed." in str(context.exception)) + + @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock) + def test_get_current_status(self, mock_hook_property): + mock_hook = MagicMock() + mock_job_status = MagicMock(state=JobState.SUCCEEDED) + mock_hook.get_job_status.return_value = mock_job_status + mock_hook_property.return_value = mock_hook + + # Call the method to test + status = self.operator.get_current_status('123') + + # Verify the result + self.assertEqual(status, 'SUCCEEDED') + + # Ensure the mock was called correctly + mock_hook.get_job_status.assert_called_once_with(job_id='123') + + @patch('airflow.models.BaseOperator.defer') + @patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=MagicMock) + def test_defer_job_polling(self, mock_hook, mock_defer): + # Mock the submit_job method to return a job ID + mock_hook.submit_job.return_value = '123' + # Mock the get_job_status method to return a starting state + mock_hook.get_job_status.return_value.state = JobState.STARTING + + # Call the execute method which internally calls process_job_status and defer_job_polling + self.operator.execute(Context()) + + # Check that the defer method was called with the correct arguments + mock_defer.assert_called_once() + args, kwargs = mock_defer.call_args + self.assertIsInstance(kwargs['trigger'], AnyscaleJobTrigger) + self.assertEqual(kwargs['trigger'].job_id, '123') + self.assertEqual(kwargs['trigger'].conn_id, 'test_conn') + self.assertEqual(kwargs['method_name'], 'execute_complete') class TestRolloutAnyscaleService(unittest.TestCase): @@ -152,7 +243,36 @@ def test_execute_complete_success(self, mock_hook): self.operator.execute_complete(Context(), event) self.assertEqual(self.operator.service_params['name'], 'test_service') + @patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook', new_callable=PropertyMock) + def test_check_anyscale_hook(self, mock_hook_property): + hook = self.operator.hook + mock_hook_property.assert_called_once() + + def test_no_service_name(self): + with self.assertRaises(ValueError) as cm: + RolloutAnyscaleService( + conn_id='test_conn', + name='', # No service name + image_uri='test_image_uri', + working_dir='/test/dir', + applications=[{'name': 'app1', 'import_path': 'module.optional_submodule:app'}], + compute_config='config123', + task_id='rollout_service_test' + ) + self.assertIn("Service name is required", str(cm.exception)) + def test_no_applications(self): + with self.assertRaises(ValueError) as cm: + RolloutAnyscaleService( + conn_id='test_conn', + name='test_service', + image_uri='test_image_uri', + working_dir='/test/dir', + applications=[], # No applications + compute_config='config123', + task_id='rollout_service_test' + ) + self.assertIn("At least one application must be specified", str(cm.exception)) if __name__ == '__main__': unittest.main() diff --git a/tests/triggers/test_anyscale_triggers.py b/tests/triggers/test_anyscale_triggers.py index 4513c00..a47544c 100644 --- a/tests/triggers/test_anyscale_triggers.py +++ b/tests/triggers/test_anyscale_triggers.py @@ -71,8 +71,6 @@ def test_get_current_status(self, mock_get_job_status): # Call the method to test status = trigger.get_current_status('123') - - print(status) # Verify the result self.assertEqual(status, 'SUCCEEDED') @@ -215,14 +213,14 @@ def test_serialize(self): self.assertEqual(result, expected_output) @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status') - def test_get_current_status(self, mock_get_service_status): + def test_get_current_status_canary_0_percent(self, mock_get_service_status): # Mock the return value of get_service_status mock_service_status = MagicMock() mock_service_status.state = ServiceState.RUNNING mock_service_status.canary_version.state = ServiceState.RUNNING mock_get_service_status.return_value = mock_service_status - # Initialize the trigger + # Initialize the trigger with canary_percent set to 0.0 trigger = AnyscaleServiceTrigger(conn_id='default_conn', service_name="AstroService", expected_state=ServiceState.RUNNING, @@ -235,15 +233,39 @@ def test_get_current_status(self, mock_get_service_status): # Call the method to test status = trigger.get_current_status('AstroService') - # Print the result - print(status) - # Verify the result self.assertEqual(status, 'RUNNING') # Ensure the mock was called correctly mock_get_service_status.assert_called_once_with('AstroService') + @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status') + def test_get_current_status_canary_100_percent(self, mock_get_service_status): + # Mock the return value of get_service_status + mock_service_status = MagicMock() + mock_service_status.state = ServiceState.TERMINATED + mock_service_status.canary_version.state = ServiceState.RUNNING + mock_get_service_status.return_value = mock_service_status + + # Initialize the trigger with canary_percent set to 100.0 + trigger = AnyscaleServiceTrigger(conn_id='default_conn', + service_name="AstroService", + expected_state=ServiceState.RUNNING, + canary_percent=100.0) + + # Mock the hook property to return our mocked hook + with patch.object(AnyscaleServiceTrigger, 'hook', new_callable=PropertyMock) as mock_hook: + mock_hook.return_value.get_service_status = mock_get_service_status + + # Call the method to test + status = trigger.get_current_status('AstroService') + + # Verify the result + self.assertEqual(status, 'TERMINATED') + + # Ensure the mock was called correctly + mock_get_service_status.assert_called_once_with('AstroService') + if __name__ == '__main__':