Commit 1af9ad0

Fixing the CI with the new huggingface-hub version (#3329)
* fixing the CI
* fix it for the api docs
1 parent f292df8 commit 1af9ad0

File tree

2 files changed: +73 −29 lines changed

docs/mocked_libs.json

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@
   "hvac",
   "hvac.exceptions",
   "huggingface_hub",
+  "huggingface_hub.errors",
   "huggingface_hub.utils",
   "keras",
   "kfp",

src/zenml/integrations/huggingface/services/huggingface_deployment.py

Lines changed: 72 additions & 29 deletions
@@ -13,17 +13,18 @@
 # permissions and limitations under the License.
 """Implementation of the Hugging Face Deployment service."""
 
-from typing import Any, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional, Tuple
 
 from huggingface_hub import (
     InferenceClient,
     InferenceEndpoint,
     InferenceEndpointError,
     InferenceEndpointStatus,
+    InferenceEndpointType,
     create_inference_endpoint,
     get_inference_endpoint,
 )
-from huggingface_hub.utils import HfHubHTTPError
+from huggingface_hub.errors import HfHubHTTPError
 from pydantic import Field
 
 from zenml.client import Client
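
The commit pins the new import location outright. For code that has to tolerate both old and new huggingface-hub releases, a fallback import (not part of this PR) would look like:

# Not from the commit: version-tolerant import of HfHubHTTPError, which moved
# to huggingface_hub.errors in newer huggingface-hub releases.
try:
    from huggingface_hub.errors import HfHubHTTPError
except ImportError:
    # Older releases only expose it from huggingface_hub.utils.
    from huggingface_hub.utils import HfHubHTTPError
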
@@ -138,30 +139,67 @@ def inference_client(self) -> InferenceClient:
         """
         return self.hf_endpoint.client
 
+    def _validate_endpoint_configuration(self) -> Dict[str, str]:
+        """Validates the configuration to provision a Huggingface service.
+
+        Raises:
+            ValueError: if there is a missing value in the configuration
+
+        Returns:
+            The validated configuration values.
+        """
+        configuration = {}
+        missing_keys = []
+
+        for k, v in {
+            "repository": self.config.repository,
+            "framework": self.config.framework,
+            "accelerator": self.config.accelerator,
+            "instance_size": self.config.instance_size,
+            "instance_type": self.config.instance_type,
+            "region": self.config.region,
+            "vendor": self.config.vendor,
+            "endpoint_type": self.config.endpoint_type,
+        }.items():
+            if v is None:
+                missing_keys.append(k)
+            else:
+                configuration[k] = v
+
+        if missing_keys:
+            raise ValueError(
+                f"Missing values in the Huggingface Service "
+                f"configuration: {', '.join(missing_keys)}"
+            )
+
+        return configuration
+
     def provision(self) -> None:
         """Provision or update remote Hugging Face deployment instance.
 
         Raises:
-            Exception: If any unexpected error while creating inference endpoint.
+            Exception: If any unexpected error while creating inference
+                endpoint.
         """
         try:
-            # Attempt to create and wait for the inference endpoint
+            validated_config = self._validate_endpoint_configuration()
+
             hf_endpoint = create_inference_endpoint(
                 name=self._generate_an_endpoint_name(),
-                repository=self.config.repository,
-                framework=self.config.framework,
-                accelerator=self.config.accelerator,
-                instance_size=self.config.instance_size,
-                instance_type=self.config.instance_type,
-                region=self.config.region,
-                vendor=self.config.vendor,
+                repository=validated_config["repository"],
+                framework=validated_config["framework"],
+                accelerator=validated_config["accelerator"],
+                instance_size=validated_config["instance_size"],
+                instance_type=validated_config["instance_type"],
+                region=validated_config["region"],
+                vendor=validated_config["vendor"],
                 account_id=self.config.account_id,
                 min_replica=self.config.min_replica,
                 max_replica=self.config.max_replica,
                 revision=self.config.revision,
                 task=self.config.task,
                 custom_image=self.config.custom_image,
-                type=self.config.endpoint_type,
+                type=InferenceEndpointType(validated_config["endpoint_type"]),
                 token=self.get_token(),
                 namespace=self.config.namespace,
             ).wait(timeout=POLLING_TIMEOUT)
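
The new helper appears to exist because newer huggingface-hub releases validate create_inference_endpoint arguments more strictly: None values and a plain string for `type` are no longer tolerated. A standalone sketch of the pattern, with all config values invented for illustration:

from huggingface_hub import InferenceEndpointType

config = {
    "repository": "gpt2",          # example values, not from the PR
    "framework": "pytorch",
    "accelerator": "cpu",
    "instance_size": "x4",
    "instance_type": "intel-icl",
    "region": "us-east-1",
    "vendor": "aws",
    "endpoint_type": "protected",
}

# Fail fast with one message naming every missing key, instead of letting
# huggingface-hub reject the first None it sees.
missing = [k for k, v in config.items() if v is None]
if missing:
    raise ValueError(f"Missing configuration values: {', '.join(missing)}")

# Coerce the plain string into the enum the hub client now expects.
endpoint_type = InferenceEndpointType(config["endpoint_type"])
assert endpoint_type is InferenceEndpointType.PROTECTED
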
@@ -172,21 +210,25 @@ def provision(self) -> None:
             )
             # Catch-all for any other unexpected errors
             raise Exception(
-                f"An unexpected error occurred while provisioning the Hugging Face inference endpoint: {e}"
+                "An unexpected error occurred while provisioning the "
+                f"Hugging Face inference endpoint: {e}"
             )
 
         # Check if the endpoint URL is available after provisioning
         if hf_endpoint.url:
             logger.info(
-                f"Hugging Face inference endpoint successfully deployed and available. Endpoint URL: {hf_endpoint.url}"
+                "Hugging Face inference endpoint successfully deployed "
+                f"and available. Endpoint URL: {hf_endpoint.url}"
             )
         else:
             logger.error(
-                "Failed to start Hugging Face inference endpoint service: No URL available, please check the Hugging Face console for more details."
+                "Failed to start Hugging Face inference endpoint "
+                "service: No URL available, please check the Hugging "
+                "Face console for more details."
             )
 
     def check_status(self) -> Tuple[ServiceState, str]:
-        """Check the the current operational state of the Hugging Face deployment.
+        """Check the current operational state of the Hugging Face deployment.
 
         Returns:
             The operational state of the Hugging Face deployment and a message
@@ -196,26 +238,29 @@ def check_status(self) -> Tuple[ServiceState, str]:
         try:
             status = self.hf_endpoint.status
             if status == InferenceEndpointStatus.RUNNING:
-                return (ServiceState.ACTIVE, "")
+                return ServiceState.ACTIVE, ""
 
             elif status == InferenceEndpointStatus.SCALED_TO_ZERO:
                 return (
                     ServiceState.SCALED_TO_ZERO,
-                    "Hugging Face Inference Endpoint is scaled to zero, but still running. It will be started on demand.",
+                    "Hugging Face Inference Endpoint is scaled to zero, but "
+                    "still running. It will be started on demand.",
                 )
 
             elif status == InferenceEndpointStatus.FAILED:
                 return (
                     ServiceState.ERROR,
-                    "Hugging Face Inference Endpoint deployment is inactive or not found",
+                    "Hugging Face Inference Endpoint deployment is inactive "
+                    "or not found",
                 )
             elif status == InferenceEndpointStatus.PENDING:
-                return (ServiceState.PENDING_STARTUP, "")
-            return (ServiceState.PENDING_STARTUP, "")
+                return ServiceState.PENDING_STARTUP, ""
+            return ServiceState.PENDING_STARTUP, ""
         except (InferenceEndpointError, HfHubHTTPError):
             return (
                 ServiceState.INACTIVE,
-                "Hugging Face Inference Endpoint deployment is inactive or "
+                "Hugging Face Inference Endpoint deployment is inactive or "
+                "not found",
             )
 
     def deprovision(self, force: bool = False) -> None:
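
For reference, the branches above collapse to a simple lookup; a table-driven equivalent (not how the PR writes it, and with strings standing in for zenml's ServiceState members) makes the mapping easy to scan:

from typing import Tuple
from huggingface_hub import InferenceEndpointStatus

STATUS_MAP = {
    InferenceEndpointStatus.RUNNING: ("ACTIVE", ""),
    InferenceEndpointStatus.SCALED_TO_ZERO: (
        "SCALED_TO_ZERO",
        "Scaled to zero, but still running. It will be started on demand.",
    ),
    InferenceEndpointStatus.FAILED: (
        "ERROR",
        "Hugging Face Inference Endpoint deployment is inactive or not found",
    ),
}

def lookup(status: InferenceEndpointStatus) -> Tuple[str, str]:
    # Anything unrecognized (including PENDING) is treated as starting up,
    # matching the fall-through in check_status above.
    return STATUS_MAP.get(status, ("PENDING_STARTUP", ""))
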
@@ -253,15 +298,13 @@ def predict(self, data: "Any", max_new_tokens: int) -> "Any":
             )
         if self.prediction_url is not None:
             if self.hf_endpoint.task == "text-generation":
-                result = self.inference_client.task_generation(
+                return self.inference_client.text_generation(
                     data, max_new_tokens=max_new_tokens
                 )
-            else:
-                # TODO: Add support for all different supported tasks
-                raise NotImplementedError(
-                    "Tasks other than text-generation is not implemented."
-                )
-            return result
+        # TODO: Add support for all different supported tasks
+        raise NotImplementedError(
+            "Tasks other than text-generation is not implemented."
+        )
 
     def get_logs(
         self, follow: bool = False, tail: Optional[int] = None
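
The bug fixed here is a misspelled client method: InferenceClient has no `task_generation`, so the old code raised AttributeError at prediction time. Usage of the correct call, with a placeholder endpoint URL and token:

from huggingface_hub import InferenceClient

# Placeholder endpoint URL and token for illustration only.
client = InferenceClient(
    model="https://<endpoint-name>.endpoints.huggingface.cloud",
    token="hf_...",
)
output = client.text_generation("Write a haiku about CI.", max_new_tokens=50)
print(output)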
