Skip to content

Commit

Permalink
unit tests updated
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 12, 2024
1 parent a0f4a89 commit fd60f21
Showing 1 changed file with 43 additions and 11 deletions.
54 changes: 43 additions & 11 deletions tests/triggers/test_anyscale_triggers.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
import unittest
from unittest.mock import patch, MagicMock
import asyncio
from datetime import datetime

from anyscale.job.models import JobState
from anyscale.service.models import ServiceState

from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger,AnyscaleServiceTrigger
from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger

class TestAnyscaleJobTrigger(unittest.TestCase):
def setUp(self):
self.trigger = AnyscaleJobTrigger(conn_id='default_conn',
job_id='123',
job_start_time=datetime.now())
job_start_time=datetime.now().timestamp())

@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status')
def test_is_terminal_status(self, mock_get_status):
mock_get_status.return_value = 'COMPLETED'
self.assertTrue(self.trigger.is_terminal_status('123'))

@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status')
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status')
def test_is_not_terminal_status(self, mock_get_status):
mock_get_status.return_value = False
mock_get_status.return_value = 'RUNNING'
self.assertFalse(self.trigger.is_terminal_status('123'))

@patch('asyncio.sleep', return_value=None)
@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.get_current_status', side_effect=['RUNNING', 'RUNNING', 'COMPLETED'])
async def test_run_successful_completion(self, mock_get_status, mock_sleep):
events = []
async for event in self.trigger.run():
self.assertIn('status', event)
self.assertEqual(event['status'], 'COMPLETED')
events.append(event)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], 'COMPLETED')

@patch('time.time', side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout
@patch('asyncio.sleep', return_value=None)
async def test_run_timeout(self, mock_sleep, mock_time):
events = []
async for event in self.trigger.run():
events.append(event)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], 'timeout')

@patch('anyscale_provider.triggers.anyscale.AnyscaleJobTrigger.is_terminal_status', side_effect=Exception("Error occurred"))
async def test_run_exception(self, mock_is_terminal_status):
events = []
async for event in self.trigger.run():
self.assertEqual(event['status'], 'timeout')
events.append(event)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], JobState.FAILED)
self.assertIn('Error occurred', events[0]['message'])

class TestAnyscaleServiceTrigger(unittest.TestCase):
def setUp(self):
Expand All @@ -50,15 +66,31 @@ def test_check_current_status(self, mock_get_status):
@patch('asyncio.sleep', return_value=None)
@patch('anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.get_current_status', side_effect=['STARTING', 'UPDATING', 'RUNNING'])
async def test_run_successful(self, mock_get_status, mock_sleep):
events = []
async for event in self.trigger.run():
self.assertEqual(event['status'], 'success')
self.assertIn('Service deployment succeeded', event['message'])
events.append(event)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], ServiceState.RUNNING)
self.assertIn('Service deployment succeeded', events[0]['message'])

@patch('time.time', side_effect=[100, 200, 300, 400, 10000]) # Simulating time passing and timeout
@patch('asyncio.sleep', return_value=None)
async def test_run_timeout(self, mock_sleep, mock_time):
events = []
async for event in self.trigger.run():
events.append(event)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], ServiceState.UNKNOWN)
self.assertIn('did not reach RUNNING within the timeout period', events[0]['message'])

@patch('anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger.check_current_status', side_effect=Exception("Error occurred"))
async def test_run_exception(self, mock_check_current_status):
events = []
async for event in self.trigger.run():
self.assertEqual(event['status'], 'timeout')
events.append(event)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], ServiceState.SYSTEM_FAILURE)
self.assertIn('Error occurred', events[0]['message'])

if __name__ == '__main__':
unittest.main()
unittest.main()

0 comments on commit fd60f21

Please sign in to comment.