From f8643a67ec4cb960c32e26a9413efd0fefd83a31 Mon Sep 17 00:00:00 2001 From: Lucain Date: Mon, 17 Jun 2024 10:38:54 +0200 Subject: [PATCH] Do not raise on `.resume()` if Inference Endpoint is already running (#2335) * Do not raise on .resume() if Inference Endpoints is already running * make style * fix test --- src/huggingface_hub/_inference_endpoints.py | 11 +++++++++-- src/huggingface_hub/hf_api.py | 19 +++++++++++++++++-- tests/test_inference_endpoints.py | 2 +- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/huggingface_hub/_inference_endpoints.py b/src/huggingface_hub/_inference_endpoints.py index 92e407b81a..cff348bca0 100644 --- a/src/huggingface_hub/_inference_endpoints.py +++ b/src/huggingface_hub/_inference_endpoints.py @@ -315,16 +315,23 @@ def pause(self) -> "InferenceEndpoint": self._populate_from_raw() return self - def resume(self) -> "InferenceEndpoint": + def resume(self, running_ok: bool = True) -> "InferenceEndpoint": """Resume the Inference Endpoint. This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the latest data from the server. + Args: + running_ok (`bool`, *optional*): + If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to + `True`. + Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. """ - obj = self._api.resume_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + obj = self._api.resume_inference_endpoint( + name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token + ) # type: ignore [arg-type] self.raw = obj.raw self._populate_from_raw() return self diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 54f9f79925..901bc4f4be 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -7489,7 +7489,12 @@ def pause_inference_endpoint( return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) def resume_inference_endpoint( - self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + self, + name: str, + *, + namespace: Optional[str] = None, + running_ok: bool = True, + token: Union[bool, str, None] = None, ) -> InferenceEndpoint: """Resume an Inference Endpoint. @@ -7500,6 +7505,9 @@ def resume_inference_endpoint( The name of the Inference Endpoint to resume. namespace (`str`, *optional*): The namespace in which the Inference Endpoint is located. Defaults to the current user. + running_ok (`bool`, *optional*): + If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to + `True`. token (Union[bool, str, None], optional): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see @@ -7515,7 +7523,14 @@ def resume_inference_endpoint( f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", headers=self._build_hf_headers(token=token), ) - hf_raise_for_status(response) + try: + hf_raise_for_status(response) + except HfHubHTTPError as error: + # If already running (and it's ok), then fetch current status and return + if running_ok and error.response.status_code == 400 and "already running" in error.response.text: + return self.get_inference_endpoint(name, namespace=namespace, token=token) + # Otherwise, raise the error + raise return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) diff --git a/tests/test_inference_endpoints.py b/tests/test_inference_endpoints.py index 000609240f..019d04f57e 100644 --- a/tests/test_inference_endpoints.py +++ b/tests/test_inference_endpoints.py @@ -256,4 +256,4 @@ def test_resume(mock: Mock): endpoint = InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo") mock.return_value = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") endpoint.resume() - mock.assert_called_once_with(namespace="foo", name="my-endpoint-name", token=None) + mock.assert_called_once_with(namespace="foo", name="my-endpoint-name", token=None, running_ok=True)