Skip to content

Commit

Permalink
hooks updated
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 13, 2024
1 parent 0e52c90 commit 70b1841
Showing 1 changed file with 70 additions and 54 deletions.
124 changes: 70 additions & 54 deletions tests/hooks/test_anyscale_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def test_successful_initialization(self, mock_anyscale, mock_get_connection):
)
hook = AnyscaleHook()
assert hook.get_connection('anyscale_default').password == API_KEY

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection')
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_init_with_env_token(self, mock_anyscale, mock_get_connection):
with mock.patch.dict("os.environ", {"ANYSCALE_CLI_TOKEN": API_KEY}):
mock_get_connection.return_value = Connection(
conn_id='anyscale_default',
conn_type='http',
host='localhost',
password=None, # No password in connection
extra=json.dumps({})
)
# Mock the Anyscale class to return an instance with the expected auth_token
mock_instance = mock_anyscale.return_value
mock_instance.auth_token = API_KEY

hook = AnyscaleHook()
assert hook.sdk.auth_token == API_KEY

@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_submit_job(self, mock_anyscale):
Expand Down Expand Up @@ -113,64 +131,89 @@ def test_deploy_service_error(self, mock_anyscale):
mock_sdk_instance.service.deploy.assert_called_once_with(config=service_config, in_place=False, canary_percent=10, max_surge_percent=20)
assert str(exc.value) == "Deploy service failed"

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status')
def test_get_job_status(self, mock_get_job_status):
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_get_job_status(self, mock_anyscale):
job_config = JobConfig(name="test_job", entrypoint="python script.py")
mock_get_job_status.return_value = JobStatus(id="test_job_id", name="test_job_id",
config=job_config, state=JobState.SUCCEEDED,
runs=[JobRunStatus(name="test", state=JobState.SUCCEEDED)])

# Create a mock SDK instance with a mock job status method
mock_sdk_instance = mock_anyscale.return_value
mock_sdk_instance.job.status.return_value = JobStatus(
id="test_job_id",
name="test_job",
config=job_config,
state=JobState.SUCCEEDED,
runs=[JobRunStatus(name="test", state=JobState.SUCCEEDED)]
)

# Patch the instance's sdk attribute directly
self.hook.sdk = mock_sdk_instance

result = self.hook.get_job_status("test_job_id")

mock_get_job_status.assert_called_once_with("test_job_id")
mock_sdk_instance.job.status.assert_called_once_with(job_id="test_job_id")
assert result.id == "test_job_id"
assert result.state == JobState.SUCCEEDED

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_job_status')
def test_get_job_status_error(self, mock_get_job_status):
mock_get_job_status.side_effect = AirflowException("Get job status failed")
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_get_job_status_error(self, mock_anyscale):
# Create a mock SDK instance with a mock job status method
mock_sdk_instance = mock_anyscale.return_value
mock_sdk_instance.job.status.side_effect = AirflowException("Get job status failed")

# Patch the instance's sdk attribute directly
self.hook.sdk = mock_sdk_instance

with pytest.raises(AirflowException) as exc:
self.hook.get_job_status("test_job_id")

mock_sdk_instance.job.status.assert_called_once_with(job_id="test_job_id")
assert str(exc.value) == "Get job status failed"

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_job')
def test_terminate_job(self, mock_terminate_job):
mock_terminate_job.return_value = True
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_terminate_job(self, mock_anyscale):
mock_sdk_instance = mock_anyscale.return_value
mock_sdk_instance.job.terminate.return_value = None
self.hook.sdk = mock_sdk_instance

result = self.hook.terminate_job("test_job_id", time_delay=1)

mock_terminate_job.assert_called_once_with("test_job_id", time_delay=1)
mock_sdk_instance.job.terminate.assert_called_once_with(name="test_job_id")
assert result is True

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_job')
def test_terminate_job_error(self, mock_terminate_job):
mock_terminate_job.side_effect = AirflowException("Job termination failed")
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_terminate_job_error(self, mock_anyscale):
mock_sdk_instance = mock_anyscale.return_value
mock_sdk_instance.job.terminate.side_effect = Exception("Terminate job failed")
self.hook.sdk = mock_sdk_instance

with pytest.raises(AirflowException) as exc:
self.hook.terminate_job("test_job_id", time_delay=1)

assert str(exc.value) == "Job termination failed"
mock_sdk_instance.job.terminate.assert_called_once_with(name="test_job_id")
assert str(exc.value) == "Job termination failed with error: Terminate job failed"

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_service')
def test_terminate_service(self, mock_terminate_service):
mock_terminate_service.return_value = True
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_terminate_service(self, mock_anyscale):
mock_sdk_instance = mock_anyscale.return_value
mock_sdk_instance.service.terminate.return_value = None
self.hook.sdk = mock_sdk_instance

result = self.hook.terminate_service("test_service_id", time_delay=1)

mock_terminate_service.assert_called_once_with("test_service_id", time_delay=1)
mock_sdk_instance.service.terminate.assert_called_once_with(name="test_service_id")
assert result is True

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.terminate_service')
def test_terminate_service_error(self, mock_terminate_service):
mock_terminate_service.side_effect = AirflowException("Service termination failed")
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_terminate_service_error(self, mock_anyscale):
mock_sdk_instance = mock_anyscale.return_value
mock_sdk_instance.service.terminate.side_effect = Exception("Terminate service failed")
self.hook.sdk = mock_sdk_instance

with pytest.raises(AirflowException) as exc:
self.hook.terminate_service("test_service_id", time_delay=1)

assert str(exc.value) == "Service termination failed"

mock_sdk_instance.service.terminate.assert_called_once_with(name="test_service_id")
assert str(exc.value) == "Service termination failed with error: Terminate service failed"
@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs')
def test_get_logs(self, mock_get_logs):
mock_get_logs.return_value = "job logs"
Expand All @@ -189,15 +232,6 @@ def test_get_logs_empty(self, mock_get_logs):
mock_get_logs.assert_called_once_with("test_job_id")
assert result == ""

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_logs')
def test_get_logs_error(self, mock_get_logs):
mock_get_logs.side_effect = AirflowException("Failed to get logs")

with pytest.raises(AirflowException) as exc:
self.hook.get_logs("test_job_id")

assert str(exc.value) == "Failed to get logs"

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_service_status')
def test_get_service_status(self, mock_get_service_status):
mock_service_status = ServiceStatus(id="test_service_id", name="test_service", query_url="http://example.com", state=ServiceState.RUNNING)
Expand All @@ -220,24 +254,6 @@ def test_get_service_status_error(self, mock_get_service_status):

assert str(exc.value) == "Get service status failed"

@patch('anyscale_provider.hooks.anyscale.AnyscaleHook.get_connection')
@patch("anyscale_provider.hooks.anyscale.Anyscale")
def test_init_with_env_token(self, mock_anyscale, mock_get_connection):
with mock.patch.dict("os.environ", {"ANYSCALE_CLI_TOKEN": API_KEY}):
mock_get_connection.return_value = Connection(
conn_id='anyscale_default',
conn_type='http',
host='localhost',
password=None, # No password in connection
extra=json.dumps({})
)
# Mock the Anyscale class to return an instance with the expected auth_token
mock_instance = mock_anyscale.return_value
mock_instance.auth_token = API_KEY

hook = AnyscaleHook()
assert hook.sdk.auth_token == API_KEY

@patch("anyscale_provider.hooks.anyscale.time.sleep", return_value=None)
def test_terminate_job_with_delay(self, mock_sleep):
with patch.object(self.hook.sdk.job, 'terminate', return_value=None) as mock_terminate:
Expand Down

0 comments on commit 70b1841

Please sign in to comment.