From 6456491ac238ede6e9f4b33bea1cf125bb25cab0 Mon Sep 17 00:00:00 2001 From: Ajinkya Bogle <119069756+Ajinkya-25@users.noreply.github.com> Date: Thu, 20 Feb 2025 14:39:45 +0530 Subject: [PATCH] bug fix in inference_endpoint wait function for proper waiting on update (#2867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bug fix in inference_endpoint wait for proper waiting on update * Update src/huggingface_hub/_inference_endpoints.py improve code clarity and added logging based on review Co-authored-by: Célina * changes in infernce_endpoint wait function for robust behaviour and addition of test case in test_inference_endpoint for testing changes in wait function * changes in test case test_wait_update --------- Co-authored-by: Célina --- src/huggingface_hub/_inference_endpoints.py | 17 +++--- tests/test_inference_endpoints.py | 59 ++++++++++++++++++++- 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/huggingface_hub/_inference_endpoints.py b/src/huggingface_hub/_inference_endpoints.py index ad5c34ad31..37733fef1b 100644 --- a/src/huggingface_hub/_inference_endpoints.py +++ b/src/huggingface_hub/_inference_endpoints.py @@ -207,16 +207,21 @@ def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "Infere start = time.time() while True: - if self.url is not None: - # Means the URL is provisioned => check if the endpoint is reachable - response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token)) - if response.status_code == 200: - logger.info("Inference Endpoint is ready to be used.") - return self if self.status == InferenceEndpointStatus.FAILED: raise InferenceEndpointError( f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information." ) + if self.status == InferenceEndpointStatus.UPDATE_FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to update. Please check the logs for more information." + ) + if self.status == InferenceEndpointStatus.RUNNING and self.url is not None: + # Verify the endpoint is actually reachable + response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token)) + if response.status_code == 200: + logger.info("Inference Endpoint is ready to be used.") + return self + if timeout is not None: if time.time() - start > timeout: raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.") diff --git a/tests/test_inference_endpoints.py b/tests/test_inference_endpoints.py index 6b44466338..265b700c0f 100644 --- a/tests/test_inference_endpoints.py +++ b/tests/test_inference_endpoints.py @@ -1,5 +1,6 @@ from datetime import datetime, timezone -from unittest.mock import Mock, patch +from itertools import chain, repeat +from unittest.mock import MagicMock, Mock, patch import pytest @@ -109,6 +110,39 @@ "targetReplica": 1, }, } +# added for test_wait_update function +MOCK_UPDATE = { + "name": "my-endpoint-name", + "type": "protected", + "accountId": None, + "provider": {"vendor": "aws", "region": "us-east-1"}, + "compute": { + "accelerator": "cpu", + "instanceType": "intel-icl", + "instanceSize": "x2", + "scaling": {"minReplica": 0, "maxReplica": 1}, + }, + "model": { + "repository": "gpt2", + "revision": "11c5a3d5811f50298f278a704980280950aedb10", + "task": "text-generation", + "framework": "pytorch", + "image": {"huggingface": {}}, + "secret": {"token": "my-token"}, + }, + "status": { + "createdAt": "2023-10-26T12:41:53.263078506Z", + "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, + "updatedAt": "2023-10-26T12:41:53.263079138Z", + "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, + "private": None, + "state": "updating", + "url": "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud", + "message": "Endpoint waiting for the update", + "readyReplica": 0, + "targetReplica": 1, + }, +} def test_from_raw_initialization(): @@ -189,7 +223,7 @@ def test_fetch(mock_get: Mock): @patch("huggingface_hub._inference_endpoints.get_session") @patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint") def test_wait_until_running(mock_get: Mock, mock_session: Mock): - """Test waits waits until the endpoint is ready.""" + """Test waits until the endpoint is ready.""" endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") mock_get.side_effect = [ @@ -244,6 +278,27 @@ def test_wait_failed(mock_get: Mock): endpoint.wait(refresh_every=0.001) +@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint") +@patch("huggingface_hub._inference_endpoints.get_session") +def test_wait_update(mock_get_session, mock_get_inference_endpoint): + """Test that wait() returns when the endpoint transitions to running.""" + endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") + # Create an iterator that yields three MOCK_UPDATE responses,and then infinitely yields MOCK_RUNNING responses. + responses = chain( + [InferenceEndpoint.from_raw(MOCK_UPDATE, namespace="foo")] * 3, + repeat(InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo")), + ) + mock_get_inference_endpoint.side_effect = lambda *args, **kwargs: next(responses) + + # Patch the get_session().get() call to always return a fake response with status_code 200. + fake_response = MagicMock() + fake_response.status_code = 200 + mock_get_session.return_value.get.return_value = fake_response + + endpoint.wait(refresh_every=0.05) + assert endpoint.status == "running" + + @patch("huggingface_hub.hf_api.HfApi.pause_inference_endpoint") def test_pause(mock: Mock): """Test `pause` calls the correct alias."""