diff --git a/tests/hooks/test_anyscale_hook.py b/tests/hooks/test_anyscale_hook.py index 4787179..3bef42c 100644 --- a/tests/hooks/test_anyscale_hook.py +++ b/tests/hooks/test_anyscale_hook.py @@ -4,9 +4,8 @@ from unittest.mock import patch, MagicMock from airflow.exceptions import AirflowException from airflow.models import Connection -from anyscale import Anyscale from anyscale.job.models import JobConfig, JobStatus, JobState, JobRunStatus -from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceVersionState, ServiceState +from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceState from anyscale_provider.hooks.anyscale import AnyscaleHook API_KEY = "api_key_value" @@ -179,4 +178,56 @@ def test_get_service_status_error(self, mock_get_service_status): with pytest.raises(AirflowException) as exc: self.hook.get_service_status("test_service_name") - assert str(exc.value) == "Get service status failed" \ No newline at end of file + assert str(exc.value) == "Get service status failed" + + @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection') + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_init_with_env_token(self, mock_anyscale, mock_get_connection): + with mock.patch.dict("os.environ", {"ANYSCALE_CLI_TOKEN": API_KEY}): + mock_get_connection.return_value = Connection( + conn_id='anyscale_default', + conn_type='http', + host='localhost', + password=None, # No password in connection + extra=json.dumps({}) + ) + # Mock the Anyscale class to return an instance with the expected auth_token + mock_instance = mock_anyscale.return_value + mock_instance.auth_token = API_KEY + + hook = AnyscaleHook() + assert hook.sdk.auth_token == API_KEY + + @patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None) + def test_terminate_job_with_delay(self, mock_sleep): + with patch.object(self.hook.sdk.job, 'terminate', return_value=None) as mock_terminate: + result = self.hook.terminate_job("test_job_id", time_delay=1) + mock_terminate.assert_called_once_with(name="test_job_id") + mock_sleep.assert_called_once_with(1) + assert result is True + + @patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None) + def test_terminate_service_with_delay(self, mock_sleep): + with patch.object(self.hook.sdk.service, 'terminate', return_value=None) as mock_terminate: + result = self.hook.terminate_service("test_service_id", time_delay=1) + mock_terminate.assert_called_once_with(name="test_service_id") + mock_sleep.assert_called_once_with(1) + assert result is True + + @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs') + def test_get_logs_empty(self, mock_get_logs): + mock_get_logs.return_value = "" + + result = self.hook.get_logs("test_job_id") + + mock_get_logs.assert_called_once_with("test_job_id") + assert result == "" + + @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs') + def test_get_logs_error(self, mock_get_logs): + mock_get_logs.side_effect = AirflowException("Failed to get logs") + + with pytest.raises(AirflowException) as exc: + self.hook.get_logs("test_job_id") + + assert str(exc.value) == "Failed to get logs" \ No newline at end of file