diff --git a/tests/hooks/test_anyscale_hook.py b/tests/hooks/test_anyscale_hook.py index ca06b94..0bd534f 100644 --- a/tests/hooks/test_anyscale_hook.py +++ b/tests/hooks/test_anyscale_hook.py @@ -50,6 +50,24 @@ def test_successful_initialization(self, mock_anyscale, mock_get_connection): ) hook = AnyscaleHook() assert hook.get_connection('anyscale_default').password == API_KEY + + @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection') + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_init_with_env_token(self, mock_anyscale, mock_get_connection): + with mock.patch.dict("os.environ", {"ANYSCALE_CLI_TOKEN": API_KEY}): + mock_get_connection.return_value = Connection( + conn_id='anyscale_default', + conn_type='http', + host='localhost', + password=None, # No password in connection + extra=json.dumps({}) + ) + # Mock the Anyscale class to return an instance with the expected auth_token + mock_instance = mock_anyscale.return_value + mock_instance.auth_token = API_KEY + + hook = AnyscaleHook() + assert hook.sdk.auth_token == API_KEY @patch("anyscale_provider.hooks.anyscale.Anyscale") def test_submit_job(self, mock_anyscale): @@ -113,64 +131,89 @@ def test_deploy_service_error(self, mock_anyscale): mock_sdk_instance.service.deploy.assert_called_once_with(config=service_config, in_place=False, canary_percent=10, max_surge_percent=20) assert str(exc.value) == "Deploy service failed" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status') - def test_get_job_status(self, mock_get_job_status): + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_get_job_status(self, mock_anyscale): job_config = JobConfig(name="test_job", entrypoint="python script.py") - mock_get_job_status.return_value = JobStatus(id="test_job_id", name="test_job_id", - config=job_config, state=JobState.SUCCEEDED, - runs=[JobRunStatus(name="test", state=JobState.SUCCEEDED)]) + + # Create a mock SDK instance with a mock job status method + mock_sdk_instance = mock_anyscale.return_value + mock_sdk_instance.job.status.return_value = JobStatus( + id="test_job_id", + name="test_job", + config=job_config, + state=JobState.SUCCEEDED, + runs=[JobRunStatus(name="test", state=JobState.SUCCEEDED)] + ) + + # Patch the instance's sdk attribute directly + self.hook.sdk = mock_sdk_instance result = self.hook.get_job_status("test_job_id") - mock_get_job_status.assert_called_once_with("test_job_id") + mock_sdk_instance.job.status.assert_called_once_with(job_id="test_job_id") assert result.id == "test_job_id" assert result.state == JobState.SUCCEEDED - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status') - def test_get_job_status_error(self, mock_get_job_status): - mock_get_job_status.side_effect = AirflowException("Get job status failed") + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_get_job_status_error(self, mock_anyscale): + # Create a mock SDK instance with a mock job status method + mock_sdk_instance = mock_anyscale.return_value + mock_sdk_instance.job.status.side_effect = AirflowException("Get job status failed") + + # Patch the instance's sdk attribute directly + self.hook.sdk = mock_sdk_instance with pytest.raises(AirflowException) as exc: self.hook.get_job_status("test_job_id") + mock_sdk_instance.job.status.assert_called_once_with(job_id="test_job_id") assert str(exc.value) == "Get job status failed" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_job') - def test_terminate_job(self, mock_terminate_job): - mock_terminate_job.return_value = True + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_terminate_job(self, mock_anyscale): + mock_sdk_instance = mock_anyscale.return_value + mock_sdk_instance.job.terminate.return_value = None + self.hook.sdk = mock_sdk_instance result = self.hook.terminate_job("test_job_id", time_delay=1) - mock_terminate_job.assert_called_once_with("test_job_id", time_delay=1) + mock_sdk_instance.job.terminate.assert_called_once_with(name="test_job_id") assert result is True - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_job') - def test_terminate_job_error(self, mock_terminate_job): - mock_terminate_job.side_effect = AirflowException("Job termination failed") + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_terminate_job_error(self, mock_anyscale): + mock_sdk_instance = mock_anyscale.return_value + mock_sdk_instance.job.terminate.side_effect = Exception("Terminate job failed") + self.hook.sdk = mock_sdk_instance with pytest.raises(AirflowException) as exc: self.hook.terminate_job("test_job_id", time_delay=1) - assert str(exc.value) == "Job termination failed" + mock_sdk_instance.job.terminate.assert_called_once_with(name="test_job_id") + assert str(exc.value) == "Job termination failed with error: Terminate job failed" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_service') - def test_terminate_service(self, mock_terminate_service): - mock_terminate_service.return_value = True + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_terminate_service(self, mock_anyscale): + mock_sdk_instance = mock_anyscale.return_value + mock_sdk_instance.service.terminate.return_value = None + self.hook.sdk = mock_sdk_instance result = self.hook.terminate_service("test_service_id", time_delay=1) - mock_terminate_service.assert_called_once_with("test_service_id", time_delay=1) + mock_sdk_instance.service.terminate.assert_called_once_with(name="test_service_id") assert result is True - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_service') - def test_terminate_service_error(self, mock_terminate_service): - mock_terminate_service.side_effect = AirflowException("Service termination failed") + @patch("anyscale_provider.hooks.anyscale.Anyscale") + def test_terminate_service_error(self, mock_anyscale): + mock_sdk_instance = mock_anyscale.return_value + mock_sdk_instance.service.terminate.side_effect = Exception("Terminate service failed") + self.hook.sdk = mock_sdk_instance with pytest.raises(AirflowException) as exc: self.hook.terminate_service("test_service_id", time_delay=1) - - assert str(exc.value) == "Service termination failed" - + mock_sdk_instance.service.terminate.assert_called_once_with(name="test_service_id") + assert str(exc.value) == "Service termination failed with error: Terminate service failed" + @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs') def test_get_logs(self, mock_get_logs): mock_get_logs.return_value = "job logs" @@ -189,15 +232,6 @@ def test_get_logs_empty(self, mock_get_logs): mock_get_logs.assert_called_once_with("test_job_id") assert result == "" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs') - def test_get_logs_error(self, mock_get_logs): - mock_get_logs.side_effect = AirflowException("Failed to get logs") - - with pytest.raises(AirflowException) as exc: - self.hook.get_logs("test_job_id") - - assert str(exc.value) == "Failed to get logs" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status') def test_get_service_status(self, mock_get_service_status): mock_service_status = ServiceStatus(id="test_service_id", name="test_service", query_url="http://example.com", state=ServiceState.RUNNING) @@ -220,24 +254,6 @@ def test_get_service_status_error(self, mock_get_service_status): assert str(exc.value) == "Get service status failed" - @patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection') - @patch("anyscale_provider.hooks.anyscale.Anyscale") - def test_init_with_env_token(self, mock_anyscale, mock_get_connection): - with mock.patch.dict("os.environ", {"ANYSCALE_CLI_TOKEN": API_KEY}): - mock_get_connection.return_value = Connection( - conn_id='anyscale_default', - conn_type='http', - host='localhost', - password=None, # No password in connection - extra=json.dumps({}) - ) - # Mock the Anyscale class to return an instance with the expected auth_token - mock_instance = mock_anyscale.return_value - mock_instance.auth_token = API_KEY - - hook = AnyscaleHook() - assert hook.sdk.auth_token == API_KEY - @patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None) def test_terminate_job_with_delay(self, mock_sleep): with patch.object(self.hook.sdk.job, 'terminate', return_value=None) as mock_terminate: