Skip to content

Commit

Permalink
Added additional unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 13, 2024
1 parent 77a4e21 commit 5d1c284
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 9 deletions.
124 changes: 122 additions & 2 deletions tests/operators/test_anyscale_operators.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock,PropertyMock
from airflow.utils.context import Context
from airflow.exceptions import AirflowException, TaskDeferred
from anyscale.job.models import JobState
from anyscale.service.models import ServiceState
from anyscale_provider.operators.anyscale import SubmitAnyscaleJob
from anyscale_provider.operators.anyscale import RolloutAnyscaleService
from anyscale_provider.triggers.anyscale import AnyscaleServiceTrigger
from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger,AnyscaleServiceTrigger


class TestSubmitAnyscaleJob(unittest.TestCase):
Expand Down Expand Up @@ -77,6 +77,97 @@ def test_execute_complete_failure(self, mock_hook):
with self.assertRaises(AirflowException) as context:
self.operator.execute_complete(Context(), event)
self.assertTrue("Job 123 failed with error" in str(context.exception))

def test_no_job_name(self):
with self.assertRaises(AirflowException) as context:
SubmitAnyscaleJob(
conn_id='test_conn',
name='', # No job name
image_uri='test_image_uri',
compute_config={},
working_dir='/test/dir',
entrypoint='test_entrypoint',
task_id='submit_job_test'
)
self.assertTrue("Job name is required." in str(context.exception))

def test_no_entrypoint_provided(self):
with self.assertRaises(AirflowException) as context:
SubmitAnyscaleJob(
conn_id='test_conn',
name='test_job',
image_uri='test_image_uri',
compute_config={},
working_dir='/test/dir',
entrypoint='', # No entrypoint
task_id='submit_job_test'
)
self.assertTrue("Entrypoint must be specified." in str(context.exception))

@patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock)
def test_check_anyscale_hook(self, mock_hook_property):
# Access the hook property
hook = self.operator.hook
# Verify that the hook property was accessed
mock_hook_property.assert_called_once()

@patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock)
def test_execute_with_no_hook(self, mock_hook_property):
# Simulate the hook not being available by raising an AirflowException
mock_hook_property.side_effect = AirflowException("SDK is not available.")

# Execute the operator and expect it to raise an AirflowException
with self.assertRaises(AirflowException) as context:
self.operator.execute(Context())

self.assertTrue("SDK is not available." in str(context.exception))

@patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.get_current_status')
@patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=MagicMock)
def test_job_state_failed(self, mock_hook, mock_get_status):
job_result_mock = MagicMock()
job_result_mock.id = '123'
mock_hook.submit_job.return_value = '123'
mock_get_status.return_value = JobState.FAILED

with self.assertRaises(AirflowException) as context:
self.operator.execute(Context())
self.assertTrue("Job 123 failed." in str(context.exception))

@patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=PropertyMock)
def test_get_current_status(self, mock_hook_property):
mock_hook = MagicMock()
mock_job_status = MagicMock(state=JobState.SUCCEEDED)
mock_hook.get_job_status.return_value = mock_job_status
mock_hook_property.return_value = mock_hook

# Call the method to test
status = self.operator.get_current_status('123')

# Verify the result
self.assertEqual(status, 'SUCCEEDED')

# Ensure the mock was called correctly
mock_hook.get_job_status.assert_called_once_with(job_id='123')

@patch('airflow.models.BaseOperator.defer')
@patch('anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook', new_callable=MagicMock)
def test_defer_job_polling(self, mock_hook, mock_defer):
# Mock the submit_job method to return a job ID
mock_hook.submit_job.return_value = '123'
# Mock the get_job_status method to return a starting state
mock_hook.get_job_status.return_value.state = JobState.STARTING

# Call the execute method which internally calls process_job_status and defer_job_polling
self.operator.execute(Context())

# Check that the defer method was called with the correct arguments
mock_defer.assert_called_once()
args, kwargs = mock_defer.call_args
self.assertIsInstance(kwargs['trigger'], AnyscaleJobTrigger)
self.assertEqual(kwargs['trigger'].job_id, '123')
self.assertEqual(kwargs['trigger'].conn_id, 'test_conn')
self.assertEqual(kwargs['method_name'], 'execute_complete')


class TestRolloutAnyscaleService(unittest.TestCase):
Expand Down Expand Up @@ -152,7 +243,36 @@ def test_execute_complete_success(self, mock_hook):
self.operator.execute_complete(Context(), event)
self.assertEqual(self.operator.service_params['name'], 'test_service')

@patch('anyscale_provider.operators.anyscale.RolloutAnyscaleService.hook', new_callable=PropertyMock)
def test_check_anyscale_hook(self, mock_hook_property):
hook = self.operator.hook
mock_hook_property.assert_called_once()

def test_no_service_name(self):
with self.assertRaises(ValueError) as cm:
RolloutAnyscaleService(
conn_id='test_conn',
name='', # No service name
image_uri='test_image_uri',
working_dir='/test/dir',
applications=[{'name': 'app1', 'import_path': 'module.optional_submodule:app'}],
compute_config='config123',
task_id='rollout_service_test'
)
self.assertIn("Service name is required", str(cm.exception))

def test_no_applications(self):
with self.assertRaises(ValueError) as cm:
RolloutAnyscaleService(
conn_id='test_conn',
name='test_service',
image_uri='test_image_uri',
working_dir='/test/dir',
applications=[], # No applications
compute_config='config123',
task_id='rollout_service_test'
)
self.assertIn("At least one application must be specified", str(cm.exception))

if __name__ == '__main__':
unittest.main()
36 changes: 29 additions & 7 deletions tests/triggers/test_anyscale_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def test_get_current_status(self, mock_get_job_status):

# Call the method to test
status = trigger.get_current_status('123')

print(status)

# Verify the result
self.assertEqual(status, 'SUCCEEDED')
Expand Down Expand Up @@ -215,14 +213,14 @@ def test_serialize(self):
self.assertEqual(result, expected_output)

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status')
def test_get_current_status(self, mock_get_service_status):
def test_get_current_status_canary_0_percent(self, mock_get_service_status):
# Mock the return value of get_service_status
mock_service_status = MagicMock()
mock_service_status.state = ServiceState.RUNNING
mock_service_status.canary_version.state = ServiceState.RUNNING
mock_get_service_status.return_value = mock_service_status

# Initialize the trigger
# Initialize the trigger with canary_percent set to 0.0
trigger = AnyscaleServiceTrigger(conn_id='default_conn',
service_name="AstroService",
expected_state=ServiceState.RUNNING,
Expand All @@ -235,15 +233,39 @@ def test_get_current_status(self, mock_get_service_status):
# Call the method to test
status = trigger.get_current_status('AstroService')

# Print the result
print(status)

# Verify the result
self.assertEqual(status, 'RUNNING')

# Ensure the mock was called correctly
mock_get_service_status.assert_called_once_with('AstroService')

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status')
def test_get_current_status_canary_100_percent(self, mock_get_service_status):
# Mock the return value of get_service_status
mock_service_status = MagicMock()
mock_service_status.state = ServiceState.TERMINATED
mock_service_status.canary_version.state = ServiceState.RUNNING
mock_get_service_status.return_value = mock_service_status

# Initialize the trigger with canary_percent set to 100.0
trigger = AnyscaleServiceTrigger(conn_id='default_conn',
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=100.0)

# Mock the hook property to return our mocked hook
with patch.object(AnyscaleServiceTrigger, 'hook', new_callable=PropertyMock) as mock_hook:
mock_hook.return_value.get_service_status = mock_get_service_status

# Call the method to test
status = trigger.get_current_status('AstroService')

# Verify the result
self.assertEqual(status, 'TERMINATED')

# Ensure the mock was called correctly
mock_get_service_status.assert_called_once_with('AstroService')



if __name__ == '__main__':
Expand Down

0 comments on commit 5d1c284

Please sign in to comment.