Skip to content

Commit

Permalink
job trigger updated
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 13, 2024
1 parent 461d219 commit 3b9ce1b
Showing 1 changed file with 106 additions and 1 deletion.
107 changes: 106 additions & 1 deletion tests/triggers/test_anyscale_triggers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, AsyncMock
import asyncio
from datetime import datetime
import os
Expand Down Expand Up @@ -122,6 +122,69 @@ def test_serialize(self):

# Check if the result matches the expected output
self.assertEqual(result, expected_output)

@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status', new_callable=AsyncMock)
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.hook')
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status')
async def test_run_method_no_job_id(self, mock_get_current_status, mock_hook, mock_is_terminal_status):
trigger = AnyscaleJobTrigger(conn_id='default_conn', job_id=None, job_start_time=datetime.now().timestamp())

events = []
async for event in trigger.run():
events.append(event)

self.assertEqual(len(events), 1)
self.assertEqual(events[0]["status"], "error")
self.assertEqual(events[0]["message"], "No job_id provided to async trigger")

@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status', new_callable=AsyncMock)
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.hook')
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status')
async def test_run_method_timeout(self, mock_get_current_status, mock_hook, mock_is_terminal_status):
mock_is_terminal_status.return_value = False
trigger = AnyscaleJobTrigger(conn_id='default_conn', job_id='123', job_start_time=datetime.now().timestamp(), poll_interval=1, timeout=1)

events = []
async for event in trigger.run():
events.append(event)

self.assertEqual(len(events), 1)
self.assertEqual(events[0]["status"], "timeout")
self.assertIn("Timeout waiting for job 123 to complete.", events[0]["message"])

@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status', new_callable=AsyncMock)
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.hook')
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status')
async def test_run_method_success(self, mock_get_current_status, mock_hook, mock_is_terminal_status):
mock_is_terminal_status.side_effect = [False, True]
mock_get_current_status.return_value = "success"

mock_hook_instance = MagicMock()
mock_hook.return_value = mock_hook_instance
mock_hook_instance.get_logs.return_value = "log line 1\nlog line 2"

trigger = AnyscaleJobTrigger(conn_id='default_conn', job_id='123', job_start_time=datetime.now().timestamp(), poll_interval=1, timeout=10)

events = []
async for event in trigger.run():
events.append(event)

self.assertEqual(events[-1]["status"], "success")
self.assertIn("Job 123 completed with status success.", events[-1]["message"])

@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status', new_callable=AsyncMock)
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.hook')
async def test_run_method_exception(self, mock_hook, mock_is_terminal_status):
mock_is_terminal_status.side_effect = Exception("Test exception")

trigger = AnyscaleJobTrigger(conn_id='default_conn', job_id='123', job_start_time=datetime.now().timestamp(), poll_interval=1, timeout=10)

events = []
async for event in trigger.run():
events.append(event)

self.assertEqual(events[-1]["status"], "failed")
self.assertIn("An error occurred while polling for job status.", events[-1]["message"])

class TestAnyscaleServiceTrigger(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -163,6 +226,48 @@ async def test_run_exception(self, mock_check_current_status):
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], ServiceState.SYSTEM_FAILURE)
self.assertIn('Error occurred', events[0]['message'])

@patch('airflow.models.connection.Connection.get_connection_from_secrets')
def test_hook_method(self, mock_get_connection):
# Configure the mock to raise AirflowNotFoundException
mock_get_connection.side_effect = AirflowNotFoundException("The conn_id `default_conn` isn't defined")

trigger = AnyscaleServiceTrigger(conn_id='default_conn',
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=0.0)

with self.assertRaises(AirflowNotFoundException) as context:
result = trigger.hook()

self.assertIn("The conn_id `default_conn` isn't defined", str(context.exception))

def test_serialize(self):

trigger = AnyscaleServiceTrigger(conn_id='default_conn',
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=0.0)

result = trigger.serialize()
expected_output = ("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger", {
"conn_id": 'default_conn',
"service_name": "AstroService",
"expected_state": ServiceState.RUNNING,
"canary_percent": 0.0,
"poll_interval": 60,
"timeout": 600
})

# Check if the result is a tuple
self.assertTrue(isinstance(result, tuple))

# Check if the tuple contains a string and a dictionary
self.assertTrue(isinstance(result[0], str))
self.assertTrue(isinstance(result[1], dict))

# Check if the result matches the expected output
self.assertEqual(result, expected_output)

if __name__ == '__main__':
unittest.main()

0 comments on commit 3b9ce1b

Please sign in to comment.