Fixing the CI with the new huggingface-hub version (#3329)
```diff
@@ -13,17 +13,18 @@
 # permissions and limitations under the License.
 """Implementation of the Hugging Face Deployment service."""
 
-from typing import Any, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional, Tuple
 
 from huggingface_hub import (
     InferenceClient,
     InferenceEndpoint,
     InferenceEndpointError,
     InferenceEndpointStatus,
+    InferenceEndpointType,
     create_inference_endpoint,
     get_inference_endpoint,
 )
-from huggingface_hub.utils import HfHubHTTPError
+from huggingface_hub.errors import HfHubHTTPError
 from pydantic import Field
 
 from zenml.client import Client
```
```diff
@@ -138,30 +139,67 @@ def inference_client(self) -> InferenceClient:
         """
         return self.hf_endpoint.client
 
+    def _validate_endpoint_configuration(self) -> Dict[str, str]:
+        """Validates the configuration to provision a Huggingface service.
+
+        Raises:
+            ValueError: if there is a missing value in the configuration
+
+        Returns:
+            The validated configuration values.
+        """
+        configuration = {}
+        missing_keys = []
+
+        for k, v in {
+            "repository": self.config.repository,
+            "framework": self.config.framework,
+            "accelerator": self.config.accelerator,
+            "instance_size": self.config.instance_size,
+            "instance_type": self.config.instance_type,
+            "region": self.config.region,
+            "vendor": self.config.vendor,
+            "endpoint_type": self.config.endpoint_type,
+        }.items():
+            if v is None:
+                missing_keys.append(k)
+            else:
+                configuration[k] = v
+
+        if missing_keys:
+            raise ValueError(
+                f"Missing values in the Huggingface Service "
+                f"configuration: {', '.join(missing_keys)}"
+            )
+
+        return configuration
+
     def provision(self) -> None:
         """Provision or update remote Hugging Face deployment instance.
 
         Raises:
-            Exception: If any unexpected error while creating inference endpoint.
+            Exception: If any unexpected error while creating inference
+                endpoint.
         """
         try:
             # Attempt to create and wait for the inference endpoint
+            validated_config = self._validate_endpoint_configuration()
+
             hf_endpoint = create_inference_endpoint(
                 name=self._generate_an_endpoint_name(),
-                repository=self.config.repository,
-                framework=self.config.framework,
-                accelerator=self.config.accelerator,
-                instance_size=self.config.instance_size,
-                instance_type=self.config.instance_type,
-                region=self.config.region,
-                vendor=self.config.vendor,
+                repository=validated_config["repository"],
+                framework=validated_config["framework"],
+                accelerator=validated_config["accelerator"],
+                instance_size=validated_config["instance_size"],
+                instance_type=validated_config["instance_type"],
+                region=validated_config["region"],
+                vendor=validated_config["vendor"],
                 account_id=self.config.account_id,
                 min_replica=self.config.min_replica,
                 max_replica=self.config.max_replica,
                 revision=self.config.revision,
                 task=self.config.task,
                 custom_image=self.config.custom_image,
-                type=self.config.endpoint_type,
+                type=InferenceEndpointType(validated_config["endpoint_type"]),
                 token=self.get_token(),
                 namespace=self.config.namespace,
             ).wait(timeout=POLLING_TIMEOUT)
```
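A side effect of the stricter typing in newer `huggingface-hub` releases is that, at least as far as the type annotations are concerned, `create_inference_endpoint` wants the endpoint type as an `InferenceEndpointType` rather than a raw string, which is why the validated value is wrapped in the enum above. A small sketch of the conversion; the `"protected"` value and the `PROTECTED` member are assumptions based on recent releases of the library, not something stated in this PR:

```python
from huggingface_hub import InferenceEndpointType

# The enum is value-based, so the string stored in the ZenML config can be
# converted directly; an unrecognized value raises ValueError early.
endpoint_type = InferenceEndpointType("protected")
assert endpoint_type is InferenceEndpointType.PROTECTED
```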
```diff
@@ -172,21 +210,25 @@ def provision(self) -> None:
             )
             # Catch-all for any other unexpected errors
             raise Exception(
-                f"An unexpected error occurred while provisioning the Hugging Face inference endpoint: {e}"
+                "An unexpected error occurred while provisioning the "
+                f"Hugging Face inference endpoint: {e}"
             )
 
         # Check if the endpoint URL is available after provisioning
         if hf_endpoint.url:
             logger.info(
-                f"Hugging Face inference endpoint successfully deployed and available. Endpoint URL: {hf_endpoint.url}"
+                "Hugging Face inference endpoint successfully deployed "
+                f"and available. Endpoint URL: {hf_endpoint.url}"
             )
         else:
             logger.error(
-                "Failed to start Hugging Face inference endpoint service: No URL available, please check the Hugging Face console for more details."
+                "Failed to start Hugging Face inference endpoint "
+                "service: No URL available, please check the Hugging "
+                "Face console for more details."
             )
 
     def check_status(self) -> Tuple[ServiceState, str]:
-        """Check the the current operational state of the Hugging Face deployment.
+        """Check the current operational state of the Hugging Face deployment.
 
         Returns:
             The operational state of the Hugging Face deployment and a message
```
```diff
@@ -196,26 +238,29 @@ def check_status(self) -> Tuple[ServiceState, str]:
         try:
             status = self.hf_endpoint.status
             if status == InferenceEndpointStatus.RUNNING:
-                return (ServiceState.ACTIVE, "")
+                return ServiceState.ACTIVE, ""
 
             elif status == InferenceEndpointStatus.SCALED_TO_ZERO:
                 return (
                     ServiceState.SCALED_TO_ZERO,
-                    "Hugging Face Inference Endpoint is scaled to zero, but still running. It will be started on demand.",
+                    "Hugging Face Inference Endpoint is scaled to zero, but "
+                    "still running. It will be started on demand.",
                 )
 
             elif status == InferenceEndpointStatus.FAILED:
                 return (
                     ServiceState.ERROR,
-                    "Hugging Face Inference Endpoint deployment is inactive or not found",
+                    "Hugging Face Inference Endpoint deployment is inactive "
+                    "or not found",
                 )
             elif status == InferenceEndpointStatus.PENDING:
-                return (ServiceState.PENDING_STARTUP, "")
-            return (ServiceState.PENDING_STARTUP, "")
+                return ServiceState.PENDING_STARTUP, ""
+            return ServiceState.PENDING_STARTUP, ""
         except (InferenceEndpointError, HfHubHTTPError):
             return (
                 ServiceState.INACTIVE,
-                "Hugging Face Inference Endpoint deployment is inactive or not found",
+                "Hugging Face Inference Endpoint deployment is inactive or "
+                "not found",
             )
 
     def deprovision(self, force: bool = False) -> None:
```
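The `except` clause above is where the relocated `HfHubHTTPError` import is actually used: if the endpoint cannot be fetched, the service reports itself as inactive instead of crashing. A minimal sketch of the same pattern outside of ZenML, using only calls that appear in the diff's import block; the endpoint and namespace names are made up:

```python
from huggingface_hub import InferenceEndpointError, get_inference_endpoint
from huggingface_hub.errors import HfHubHTTPError

try:
    # Hypothetical endpoint/namespace names, just to illustrate the failure mode.
    endpoint = get_inference_endpoint("my-endpoint", namespace="my-org")
    print(endpoint.status)
except (InferenceEndpointError, HfHubHTTPError):
    print("Inference endpoint is inactive or not found")
```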
```diff
@@ -253,15 +298,13 @@ def predict(self, data: "Any", max_new_tokens: int) -> "Any":
             )
         if self.prediction_url is not None:
             if self.hf_endpoint.task == "text-generation":
-                result = self.inference_client.task_generation(
+                return self.inference_client.text_generation(
                     data, max_new_tokens=max_new_tokens
                 )
-            else:
-                # TODO: Add support for all different supported tasks
-                raise NotImplementedError(
-                    "Tasks other than text-generation is not implemented."
-                )
-            return result
+            # TODO: Add support for all different supported tasks
+            raise NotImplementedError(
+                "Tasks other than text-generation is not implemented."
+            )
 
     def get_logs(
         self, follow: bool = False, tail: Optional[int] = None
```

Two inline review comments were left on the removed `task_generation` call:

Btw @schustmi, this one confused me a bit. The inference client of Huggingface never had a method called `task_generation`.

They probably introduced typing in their newest release, which means before the …
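For reference, the method the new code calls does exist on `InferenceClient`. A minimal usage sketch against a deployed endpoint; the URL is a placeholder, and in the service above the client comes from `self.hf_endpoint.client` rather than being constructed by hand:

```python
from huggingface_hub import InferenceClient

# Placeholder endpoint URL, purely for illustration.
client = InferenceClient(model="https://my-endpoint.endpoints.huggingface.cloud")
output = client.text_generation("What is ZenML?", max_new_tokens=64)
print(output)
```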
A separate review thread questioned the configuration handling introduced above:
Why are these attributes optional in the config if they're not actually optional when creating the inference endpoint? It feels like we could just fail much earlier and avoid this hack with dicts that prevents any type checking if we just made them required? Maybe @safoinme has an answer?
I am also not sure myself. This was an external contribution, and I am not sure what will happen for the people who already use this integration if I change optional values into required values in this configuration. As the call only happens in the `provision` call, I decided to go with this quick and dirty approach. As you said, perhaps @safoinme can shed more light on this.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems the base class with the optionals is also used for the component config (but then simply not used, from what I can see), so I guess we need to leave it optional here and refactor at some point.
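If the deployment service ever stops sharing these fields with the component config, the dict-based check could be replaced by required pydantic fields, which would fail at construction time rather than inside `provision`. A hypothetical sketch; the class and field layout are illustrative only and are not the actual ZenML config classes:

```python
from pydantic import BaseModel, ValidationError

class HFEndpointProvisionConfig(BaseModel):
    # All fields required: instantiation fails early if any value is missing.
    repository: str
    framework: str
    accelerator: str
    instance_size: str
    instance_type: str
    region: str
    vendor: str
    endpoint_type: str

try:
    HFEndpointProvisionConfig(repository="gpt2", framework="pytorch")
except ValidationError as e:
    print(e)  # lists the missing required fields
```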