Skip to content

Commit

Permalink
anyscale job trigger tests updated
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 13, 2024
1 parent d7c5033 commit 461d219
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions tests/triggers/test_anyscale_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from datetime import datetime
import os
import pytest
from typing import Any, Dict, AsyncIterator, Tuple, Optional
from pathlib import Path
from airflow.models import DagBag, Connection
from airflow.utils.db import create_default_connections
from airflow.utils.session import provide_session, create_session
from airflow.exceptions import AirflowNotFoundException

from anyscale.job.models import JobState
from anyscale.service.models import ServiceState

from anyscale_provider.hooks.anyscale import AnyscaleHook
from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger
from airflow.triggers.base import TriggerEvent
from airflow.models.connection import Connection
Expand Down Expand Up @@ -60,8 +60,15 @@ async def test_run_exception(self, mock_is_terminal_status):
self.assertIn('Error occurred', events[0]['message'])

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status')
def test_get_current_status(self, mock_get_job_status):
async def test_get_current_status(self, mock_get_job_status):
mock_get_job_status.return_value = MagicMock(state=JobState.SUCCEEDED)
trigger = AnyscaleJobTrigger(conn_id='default_conn',
job_id='',
job_start_time=datetime.now().timestamp())
events = []
async for event in trigger.run():
events.append(event)

status = self.trigger.get_current_status('123')
self.assertEqual(status, JobState.SUCCEEDED)
mock_get_job_status.assert_called_once_with(job_id='123')
Expand All @@ -76,6 +83,45 @@ async def test_run_no_job_id_provided(self):
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['status'], 'error')
self.assertIn('No job_id provided', events[0]['message'])

@patch('airflow.models.connection.Connection.get_connection_from_secrets')
def test_hook_method(self, mock_get_connection):
# Configure the mock to raise AirflowNotFoundException
mock_get_connection.side_effect = AirflowNotFoundException("The conn_id `default_conn` isn't defined")

trigger = AnyscaleJobTrigger(conn_id='default_conn',
job_id='123',
job_start_time=datetime.now().timestamp())

with self.assertRaises(AirflowNotFoundException) as context:
result = trigger.hook()

self.assertIn("The conn_id `default_conn` isn't defined", str(context.exception))

def test_serialize(self):
time = datetime.now().timestamp()
trigger = AnyscaleJobTrigger(conn_id='default_conn',
job_id='123',
job_start_time=time)

result = trigger.serialize()
expected_output = ("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger", {
"conn_id": 'default_conn',
"job_id": '123',
"job_start_time": time,
"poll_interval": 60,
"timeout": 3600
})

# Check if the result is a tuple
self.assertTrue(isinstance(result, tuple))

# Check if the tuple contains a string and a dictionary
self.assertTrue(isinstance(result[0], str))
self.assertTrue(isinstance(result[1], dict))

# Check if the result matches the expected output
self.assertEqual(result, expected_output)

class TestAnyscaleServiceTrigger(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 461d219

Please sign in to comment.