Skip to content

Commit

Permalink
added additional unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 12, 2024
1 parent a99148e commit 3a25891
Showing 1 changed file with 54 additions and 3 deletions.
57 changes: 54 additions & 3 deletions tests/hooks/test_anyscale_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from unittest.mock import patch, MagicMock
from airflow.exceptions import AirflowException
from airflow.models import Connection
from anyscale import Anyscale
from anyscale.job.models import JobConfig, JobStatus, JobState, JobRunStatus
from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceVersionState, ServiceState
from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceState
from anyscale_provider.hooks.anyscale import AnyscaleHook

API_KEY = "api_key_value"
Expand Down Expand Up @@ -179,4 +178,56 @@ def test_get_service_status_error(self, mock_get_service_status):
with pytest.raises(AirflowException) as exc:
self.hook.get_service_status("test_service_name")

assert str(exc.value) == "Get service status failed"
assert str(exc.value) == "Get service status failed"

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection')
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_init_with_env_token(self, mock_anyscale, mock_get_connection):
with mock.patch.dict("os.environ", {"ANYSCALE_CLI_TOKEN": API_KEY}):
mock_get_connection.return_value = Connection(
conn_id='anyscale_default',
conn_type='http',
host='localhost',
password=None, # No password in connection
extra=json.dumps({})
)
# Mock the Anyscale class to return an instance with the expected auth_token
mock_instance = mock_anyscale.return_value
mock_instance.auth_token = API_KEY

hook = AnyscaleHook()
assert hook.sdk.auth_token == API_KEY

@patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None)
def test_terminate_job_with_delay(self, mock_sleep):
with patch.object(self.hook.sdk.job, 'terminate', return_value=None) as mock_terminate:
result = self.hook.terminate_job("test_job_id", time_delay=1)
mock_terminate.assert_called_once_with(name="test_job_id")
mock_sleep.assert_called_once_with(1)
assert result is True

@patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None)
def test_terminate_service_with_delay(self, mock_sleep):
with patch.object(self.hook.sdk.service, 'terminate', return_value=None) as mock_terminate:
result = self.hook.terminate_service("test_service_id", time_delay=1)
mock_terminate.assert_called_once_with(name="test_service_id")
mock_sleep.assert_called_once_with(1)
assert result is True

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs')
def test_get_logs_empty(self, mock_get_logs):
mock_get_logs.return_value = ""

result = self.hook.get_logs("test_job_id")

mock_get_logs.assert_called_once_with("test_job_id")
assert result == ""

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs')
def test_get_logs_error(self, mock_get_logs):
mock_get_logs.side_effect = AirflowException("Failed to get logs")

with pytest.raises(AirflowException) as exc:
self.hook.get_logs("test_job_id")

assert str(exc.value) == "Failed to get logs"

0 comments on commit 3a25891

Please sign in to comment.