 # permissions and limitations under the License.
 """Implementation of the Hugging Face Deployment service."""

-from typing import Any, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional, Tuple

 from huggingface_hub import (
     InferenceClient,
     InferenceEndpoint,
     InferenceEndpointError,
     InferenceEndpointStatus,
+    InferenceEndpointType,
     create_inference_endpoint,
     get_inference_endpoint,
 )
-from huggingface_hub.utils import HfHubHTTPError
+from huggingface_hub.errors import HfHubHTTPError
 from pydantic import Field

 from zenml.client import Client
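
The import of `HfHubHTTPError` moves from `huggingface_hub.utils` to `huggingface_hub.errors`, where newer `huggingface_hub` releases define it. A minimal compatibility sketch, assuming older releases also need to keep working (something this change itself does not attempt):

```python
try:
    from huggingface_hub.errors import HfHubHTTPError
except ImportError:  # older huggingface_hub without the errors module
    from huggingface_hub.utils import HfHubHTTPError
```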
@@ -138,30 +139,67 @@ def inference_client(self) -> InferenceClient:
         """
         return self.hf_endpoint.client

+    def _validate_endpoint_configuration(self) -> Dict[str, str]:
+        """Validates the configuration to provision a Huggingface service.
+
+        Raises:
+            ValueError: if there is a missing value in the configuration
+
+        Returns:
+            The validated configuration values.
+        """
+        configuration = {}
+        missing_keys = []
+
+        for k, v in {
+            "repository": self.config.repository,
+            "framework": self.config.framework,
+            "accelerator": self.config.accelerator,
+            "instance_size": self.config.instance_size,
+            "instance_type": self.config.instance_type,
+            "region": self.config.region,
+            "vendor": self.config.vendor,
+            "endpoint_type": self.config.endpoint_type,
+        }.items():
+            if v is None:
+                missing_keys.append(k)
+            else:
+                configuration[k] = v
+
+        if missing_keys:
+            raise ValueError(
+                f"Missing values in the Huggingface Service "
+                f"configuration: {', '.join(missing_keys)}"
+            )
+
+        return configuration
+
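
The new validator partitions the required settings into present values and missing keys, then fails fast with one error naming everything that is absent. A self-contained sketch of the same contract, with illustrative values:

```python
from typing import Dict, Optional

def validate(values: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Standalone restatement of _validate_endpoint_configuration's logic."""
    missing = [k for k, v in values.items() if v is None]
    if missing:
        raise ValueError(
            "Missing values in the Huggingface Service "
            f"configuration: {', '.join(missing)}"
        )
    return {k: v for k, v in values.items() if v is not None}

print(validate({"repository": "gpt2", "vendor": "aws"}))  # both keys returned
try:
    validate({"repository": "gpt2", "vendor": None})
except ValueError as err:
    print(err)  # Missing values in the Huggingface Service configuration: vendor
```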
     def provision(self) -> None:
         """Provision or update remote Hugging Face deployment instance.

         Raises:
-            Exception: If any unexpected error while creating inference endpoint.
+            Exception: If any unexpected error while creating inference
+                endpoint.
         """
         try:
-            # Attempt to create and wait for the inference endpoint
+            validated_config = self._validate_endpoint_configuration()
+
             hf_endpoint = create_inference_endpoint(
                 name=self._generate_an_endpoint_name(),
-                repository=self.config.repository,
-                framework=self.config.framework,
-                accelerator=self.config.accelerator,
-                instance_size=self.config.instance_size,
-                instance_type=self.config.instance_type,
-                region=self.config.region,
-                vendor=self.config.vendor,
+                repository=validated_config["repository"],
+                framework=validated_config["framework"],
+                accelerator=validated_config["accelerator"],
+                instance_size=validated_config["instance_size"],
+                instance_type=validated_config["instance_type"],
+                region=validated_config["region"],
+                vendor=validated_config["vendor"],
                 account_id=self.config.account_id,
                 min_replica=self.config.min_replica,
                 max_replica=self.config.max_replica,
                 revision=self.config.revision,
                 task=self.config.task,
                 custom_image=self.config.custom_image,
-                type=self.config.endpoint_type,
+                type=InferenceEndpointType(validated_config["endpoint_type"]),
                 token=self.get_token(),
                 namespace=self.config.namespace,
             ).wait(timeout=POLLING_TIMEOUT)
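
`create_inference_endpoint` returns an `InferenceEndpoint`, and chaining `.wait(timeout=...)` blocks until the endpoint is running, raising if it fails or the timeout elapses. A minimal standalone sketch of the underlying call; all configuration values below are illustrative stand-ins for what `HuggingFaceServiceConfig` would supply:

```python
from huggingface_hub import InferenceEndpointType, create_inference_endpoint

endpoint = create_inference_endpoint(
    name="my-gpt2-endpoint",          # illustrative name
    repository="gpt2",
    framework="pytorch",
    accelerator="cpu",
    instance_size="x2",
    instance_type="intel-icl",
    region="us-east-1",
    vendor="aws",
    type=InferenceEndpointType.PROTECTED,
    token="hf_...",                   # a real user token is required
)
endpoint.wait(timeout=1200)  # blocks; raises on failure or timeout
print(endpoint.url)
```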
@@ -172,21 +210,25 @@ def provision(self) -> None:
             )
             # Catch-all for any other unexpected errors
             raise Exception(
-                f"An unexpected error occurred while provisioning the Hugging Face inference endpoint: {e}"
+                "An unexpected error occurred while provisioning the "
+                f"Hugging Face inference endpoint: {e}"
             )

         # Check if the endpoint URL is available after provisioning
         if hf_endpoint.url:
             logger.info(
-                f"Hugging Face inference endpoint successfully deployed and available. Endpoint URL: {hf_endpoint.url}"
+                "Hugging Face inference endpoint successfully deployed "
+                f"and available. Endpoint URL: {hf_endpoint.url}"
             )
         else:
             logger.error(
-                "Failed to start Hugging Face inference endpoint service: No URL available, please check the Hugging Face console for more details."
+                "Failed to start Hugging Face inference endpoint "
+                "service: No URL available, please check the Hugging "
+                "Face console for more details."
             )

     def check_status(self) -> Tuple[ServiceState, str]:
-        """Check the the current operational state of the Hugging Face deployment.
+        """Check the current operational state of the Hugging Face deployment.

         Returns:
             The operational state of the Hugging Face deployment and a message
@@ -196,26 +238,29 @@ def check_status(self) -> Tuple[ServiceState, str]:
         try:
             status = self.hf_endpoint.status
             if status == InferenceEndpointStatus.RUNNING:
-                return (ServiceState.ACTIVE, "")
+                return ServiceState.ACTIVE, ""

             elif status == InferenceEndpointStatus.SCALED_TO_ZERO:
                 return (
                     ServiceState.SCALED_TO_ZERO,
-                    "Hugging Face Inference Endpoint is scaled to zero, but still running. It will be started on demand.",
+                    "Hugging Face Inference Endpoint is scaled to zero, but "
+                    "still running. It will be started on demand.",
                 )

             elif status == InferenceEndpointStatus.FAILED:
                 return (
                     ServiceState.ERROR,
-                    "Hugging Face Inference Endpoint deployment is inactive or not found",
+                    "Hugging Face Inference Endpoint deployment is inactive "
+                    "or not found",
                 )
             elif status == InferenceEndpointStatus.PENDING:
-                return (ServiceState.PENDING_STARTUP, "")
-            return (ServiceState.PENDING_STARTUP, "")
+                return ServiceState.PENDING_STARTUP, ""
+            return ServiceState.PENDING_STARTUP, ""
         except (InferenceEndpointError, HfHubHTTPError):
             return (
                 ServiceState.INACTIVE,
-                "Hugging Face Inference Endpoint deployment is inactive or not found",
+                "Hugging Face Inference Endpoint deployment is inactive or "
+                "not found",
             )

     def deprovision(self, force: bool = False) -> None:
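
`check_status` reduces the endpoint's `InferenceEndpointStatus` to a ZenML `ServiceState`: anything not explicitly handled falls through to `PENDING_STARTUP`, and API errors map to `INACTIVE`. A compact sketch of the same mapping, using plain strings as stand-ins for the `zenml.services.ServiceState` members so it stays self-contained:

```python
from huggingface_hub import InferenceEndpointStatus

STATE_MAP = {
    InferenceEndpointStatus.RUNNING: "ACTIVE",
    InferenceEndpointStatus.SCALED_TO_ZERO: "SCALED_TO_ZERO",
    InferenceEndpointStatus.FAILED: "ERROR",
    InferenceEndpointStatus.PENDING: "PENDING_STARTUP",
}

def to_service_state(status: InferenceEndpointStatus) -> str:
    # In-between states (initializing, updating, ...) are treated as
    # still starting up, mirroring the method's final fallthrough return.
    return STATE_MAP.get(status, "PENDING_STARTUP")
```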
@@ -253,15 +298,13 @@ def predict(self, data: "Any", max_new_tokens: int) -> "Any":
             )
         if self.prediction_url is not None:
             if self.hf_endpoint.task == "text-generation":
-                result = self.inference_client.task_generation(
+                return self.inference_client.text_generation(
                     data, max_new_tokens=max_new_tokens
                 )
-            else:
-                # TODO: Add support for all different supported tasks
-                raise NotImplementedError(
-                    "Tasks other than text-generation is not implemented."
-                )
-            return result
+        # TODO: Add support for all different supported tasks
+        raise NotImplementedError(
+            "Tasks other than text-generation is not implemented."
+        )

     def get_logs(
         self, follow: bool = False, tail: Optional[int] = None
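
The `predict` change also fixes a real bug: `InferenceClient` has no `task_generation` method, so the old code would have raised `AttributeError`; `text_generation` is the correct call. A minimal sketch of that call against a deployed endpoint, where the URL and prompt are illustrative:

```python
from huggingface_hub import InferenceClient

# Point the client at the endpoint URL; a hypothetical address is used here.
client = InferenceClient(model="https://my-endpoint.endpoints.huggingface.cloud")
completion = client.text_generation("ZenML is", max_new_tokens=50)
print(completion)
```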