Skip to content

Commit

Permalink
get_current_status updated
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 13, 2024
1 parent 9dd8c7f commit 77a4e21
Showing 1 changed file with 41 additions and 24 deletions.
65 changes: 41 additions & 24 deletions tests/triggers/test_anyscale_triggers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import patch, MagicMock, AsyncMock
from unittest.mock import patch, MagicMock, PropertyMock
import asyncio
from datetime import datetime
import os
Expand All @@ -8,7 +8,7 @@
from pathlib import Path
from airflow.exceptions import AirflowNotFoundException

from anyscale.job.models import JobState, JobStatus
from anyscale.job.models import JobState, JobStatus, JobConfig, JobRunStatus
from anyscale.service.models import ServiceState, ServiceStatus

from anyscale_provider.hooks.anyscale import AnyscaleHook
Expand Down Expand Up @@ -60,20 +60,25 @@ async def test_run_exception(self, mock_is_terminal_status):
self.assertIn('Error occurred', events[0]['message'])

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status')
async def test_get_current_status(self, mock_get_job_status):
mock_get_job_status.return_value = MagicMock(state=JobStatus(id="test_job_state",
name="123",
state = JobState.SUCCEEDED))
def test_get_current_status(self, mock_get_job_status):
mock_get_job_status.return_value = MagicMock(state=JobState.SUCCEEDED)
trigger = AnyscaleJobTrigger(conn_id='default_conn',
job_id='123',
job_start_time=datetime.now().timestamp())
events = []
async for event in trigger.run():
events.append(event)

status = self.trigger.get_current_status('123')
self.assertEqual(status, JobState.SUCCEEDED)
mock_get_job_status.assert_called_once_with(job_id='123')
# Mock the hook property to return our mocked hook
with patch.object(AnyscaleJobTrigger, 'hook', new_callable=PropertyMock) as mock_hook:
mock_hook.return_value.get_job_status = mock_get_job_status

# Call the method to test
status = trigger.get_current_status('123')

print(status)

# Verify the result
self.assertEqual(status, 'SUCCEEDED')

# Ensure the mock was called correctly
mock_get_job_status.assert_called_once_with(job_id='123')

async def test_run_no_job_id_provided(self):
trigger = AnyscaleJobTrigger(conn_id='default_conn',
Expand Down Expand Up @@ -210,22 +215,34 @@ def test_serialize(self):
self.assertEqual(result, expected_output)

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status')
async def test_get_current_status(self, mock_get_service_status):
mock_get_service_status.return_value = MagicMock(state=ServiceStatus(name="AstroService",
id="123",
state=ServiceState.RUNNING,
query_url="https://sample-url"))
def test_get_current_status(self, mock_get_service_status):
# Mock the return value of get_service_status
mock_service_status = MagicMock()
mock_service_status.state = ServiceState.RUNNING
mock_service_status.canary_version.state = ServiceState.RUNNING
mock_get_service_status.return_value = mock_service_status

# Initialize the trigger
trigger = AnyscaleServiceTrigger(conn_id='default_conn',
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=0.0)
events = []
async for event in trigger.run():
events.append(event)

status = self.trigger.get_current_status('AstroService')
self.assertEqual(status, ServiceState.RUNNING)
mock_get_service_status.assert_called_once_with(service_name='AstroService')
# Mock the hook property to return our mocked hook
with patch.object(AnyscaleServiceTrigger, 'hook', new_callable=PropertyMock) as mock_hook:
mock_hook.return_value.get_service_status = mock_get_service_status

# Call the method to test
status = trigger.get_current_status('AstroService')

# Print the result
print(status)

# Verify the result
self.assertEqual(status, 'RUNNING')

# Ensure the mock was called correctly
mock_get_service_status.assert_called_once_with('AstroService')



Expand Down

0 comments on commit 77a4e21

Please sign in to comment.