1+ from collections .abc import Generator
12from dataclasses import dataclass
2- from dataclasses_json import dataclass_json , Undefined # type: ignore
3+ from enum import Enum
4+ from typing import Any
5+ from urllib .parse import urlparse
6+
37import requests
8+ from dataclasses_json import Undefined , dataclass_json # type: ignore
49from requests .structures import CaseInsensitiveDict
5- from typing import Optional , Dict , Any , Union , Generator
6- from urllib .parse import urlparse
7- from enum import Enum
810
911
1012class InferenceClientError (Exception ):
@@ -14,6 +16,8 @@ class InferenceClientError(Exception):
1416
1517
1618class AsyncStatus (str , Enum ):
19+ """Async status."""
20+
1721 Initialized = 'Initialized'
1822 Queue = 'Queue'
1923 Inference = 'Inference'
@@ -23,6 +27,8 @@ class AsyncStatus(str, Enum):
2327@dataclass_json (undefined = Undefined .EXCLUDE )
2428@dataclass
2529class InferenceResponse :
30+ """Inference response."""
31+
2632 headers : CaseInsensitiveDict [str ]
2733 status_code : int
2834 status_text : str
@@ -64,6 +70,7 @@ def _is_stream_response(self, headers: CaseInsensitiveDict[str]) -> bool:
6470 )
6571
6672 def output (self , is_text : bool = False ) -> Any :
73+ """Get response output as a string or object."""
6774 try :
6875 if is_text :
6976 return self ._original_response .text
@@ -73,8 +80,8 @@ def output(self, is_text: bool = False) -> Any:
7380 if self ._is_stream_response (self ._original_response .headers ):
7481 raise InferenceClientError (
7582 'Response might be a stream, use the stream method instead'
76- )
77- raise InferenceClientError (f'Failed to parse response as JSON: { str ( e ) } ' )
83+ ) from e
84+ raise InferenceClientError (f'Failed to parse response as JSON: { e !s } ' ) from e
7885
7986 def stream (self , chunk_size : int = 512 , as_text : bool = True ) -> Generator [Any , None , None ]:
8087 """Stream the response content.
@@ -97,11 +104,12 @@ def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any,
97104
98105
99106class InferenceClient :
107+ """Inference client."""
108+
100109 def __init__ (
101110 self , inference_key : str , endpoint_base_url : str , timeout_seconds : int = 60 * 5
102111 ) -> None :
103- """
104- Initialize the InferenceClient.
112+ """Initialize the InferenceClient.
105113
106114 Args:
107115 inference_key: The authentication key for the API
@@ -136,37 +144,33 @@ def __exit__(self, exc_type, exc_val, exc_tb):
136144 self ._session .close ()
137145
@property
def global_headers(self) -> dict[str, str]:
    """Return a snapshot of the headers applied to every request.

    The value handed back is a copy of the client's stored global
    headers, so changing it does not alter what the client holds.

    Returns:
        A copy of the current global headers.
    """
    snapshot = self._global_headers.copy()
    return snapshot
147154
def set_global_header(self, key: str, value: str) -> None:
    """Add a global header, replacing any value already stored under it.

    Global headers are the ones used for all requests.

    Args:
        key: Name of the header.
        value: Value to store for that header.
    """
    stored_headers = self._global_headers
    stored_headers[key] = value
157163
def set_global_headers(self, headers: dict[str, str]) -> None:
    """Merge several headers into the global headers in one call.

    Headers already present under the same name are overwritten; other
    stored headers are left as they are.

    Args:
        headers: Header names mapped to the values to set globally.
    """
    stored_headers = self._global_headers
    stored_headers.update(headers)
166171
167172 def remove_global_header (self , key : str ) -> None :
168- """
169- Remove a global header.
173+ """Remove a global header.
170174
171175 Args:
172176 key: Header name to remove from global headers
@@ -179,10 +183,9 @@ def _build_url(self, path: str) -> str:
179183 return f'{ self .endpoint_base_url } /{ path .lstrip ("/" )} '
180184
181185 def _build_request_headers (
182- self , request_headers : Optional [Dict [str , str ]] = None
183- ) -> Dict [str , str ]:
184- """
185- Build the final headers by merging global headers with request-specific headers.
186+ self , request_headers : dict [str , str ] | None = None
187+ ) -> dict [str , str ]:
188+ """Build the final headers by merging global headers with request-specific headers.
186189
187190 Args:
188191 request_headers: Optional headers specific to this request
@@ -196,8 +199,7 @@ def _build_request_headers(
196199 return headers
197200
198201 def _make_request (self , method : str , path : str , ** kwargs ) -> requests .Response :
199- """
200- Make an HTTP request with error handling.
202+ """Make an HTTP request with error handling.
201203
202204 Args:
203205 method: HTTP method to use
@@ -221,17 +223,19 @@ def _make_request(self, method: str, path: str, **kwargs) -> requests.Response:
221223 )
222224 response .raise_for_status ()
223225 return response
224- except requests .exceptions .Timeout :
225- raise InferenceClientError (f'Request to { path } timed out after { timeout } seconds' )
226+ except requests .exceptions .Timeout as e :
227+ raise InferenceClientError (
228+ f'Request to { path } timed out after { timeout } seconds'
229+ ) from e
226230 except requests .exceptions .RequestException as e :
227- raise InferenceClientError (f'Request to { path } failed: { str ( e ) } ' )
231+ raise InferenceClientError (f'Request to { path } failed: { e !s } ' ) from e
228232
229233 def run_sync (
230234 self ,
231- data : Dict [str , Any ],
235+ data : dict [str , Any ],
232236 path : str = '' ,
233237 timeout_seconds : int = 60 * 5 ,
234- headers : Optional [ Dict [ str , str ]] = None ,
238+ headers : dict [ str , str ] | None = None ,
235239 http_method : str = 'POST' ,
236240 stream : bool = False ,
237241 ):
@@ -269,10 +273,10 @@ def run_sync(
269273
270274 def run (
271275 self ,
272- data : Dict [str , Any ],
276+ data : dict [str , Any ],
273277 path : str = '' ,
274278 timeout_seconds : int = 60 * 5 ,
275- headers : Optional [ Dict [ str , str ]] = None ,
279+ headers : dict [ str , str ] | None = None ,
276280 http_method : str = 'POST' ,
277281 no_response : bool = False ,
278282 ):
@@ -325,23 +329,25 @@ def run(
def get(
    self,
    path: str,
    params: dict[str, Any] | None = None,
    headers: dict[str, str] | None = None,
    timeout_seconds: int | None = None,
) -> requests.Response:
    """Send a GET request through ``_make_request``.

    Args:
        path: Request path, passed on unchanged.
        params: Optional query parameters, passed on unchanged.
        headers: Optional headers for this request, passed on unchanged.
        timeout_seconds: Optional timeout, passed on unchanged.

    Returns:
        Whatever ``_make_request`` returns for the call.
    """
    request_kwargs = {
        'params': params,
        'headers': headers,
        'timeout_seconds': timeout_seconds,
    }
    return self._make_request('GET', path, **request_kwargs)
335340
336341 def post (
337342 self ,
338343 path : str ,
339- json : Optional [ Dict [ str , Any ]] = None ,
340- data : Optional [ Union [ str , Dict [str , Any ]]] = None ,
341- params : Optional [ Dict [ str , Any ]] = None ,
342- headers : Optional [ Dict [ str , str ]] = None ,
343- timeout_seconds : Optional [ int ] = None ,
344+ json : dict [ str , Any ] | None = None ,
345+ data : str | dict [str , Any ] | None = None ,
346+ params : dict [ str , Any ] | None = None ,
347+ headers : dict [ str , str ] | None = None ,
348+ timeout_seconds : int | None = None ,
344349 ) -> requests .Response :
350+ """Make POST request."""
345351 return self ._make_request (
346352 'POST' ,
347353 path ,
@@ -355,12 +361,13 @@ def post(
355361 def put (
356362 self ,
357363 path : str ,
358- json : Optional [ Dict [ str , Any ]] = None ,
359- data : Optional [ Union [ str , Dict [str , Any ]]] = None ,
360- params : Optional [ Dict [ str , Any ]] = None ,
361- headers : Optional [ Dict [ str , str ]] = None ,
362- timeout_seconds : Optional [ int ] = None ,
364+ json : dict [ str , Any ] | None = None ,
365+ data : str | dict [str , Any ] | None = None ,
366+ params : dict [ str , Any ] | None = None ,
367+ headers : dict [ str , str ] | None = None ,
368+ timeout_seconds : int | None = None ,
363369 ) -> requests .Response :
370+ """Make PUT request."""
364371 return self ._make_request (
365372 'PUT' ,
366373 path ,
@@ -374,10 +381,11 @@ def put(
374381 def delete (
375382 self ,
376383 path : str ,
377- params : Optional [ Dict [ str , Any ]] = None ,
378- headers : Optional [ Dict [ str , str ]] = None ,
379- timeout_seconds : Optional [ int ] = None ,
384+ params : dict [ str , Any ] | None = None ,
385+ headers : dict [ str , str ] | None = None ,
386+ timeout_seconds : int | None = None ,
380387 ) -> requests .Response :
388+ """Make DELETE request."""
381389 return self ._make_request (
382390 'DELETE' ,
383391 path ,
@@ -389,12 +397,13 @@ def delete(
389397 def patch (
390398 self ,
391399 path : str ,
392- json : Optional [ Dict [ str , Any ]] = None ,
393- data : Optional [ Union [ str , Dict [str , Any ]]] = None ,
394- params : Optional [ Dict [ str , Any ]] = None ,
395- headers : Optional [ Dict [ str , str ]] = None ,
396- timeout_seconds : Optional [ int ] = None ,
400+ json : dict [ str , Any ] | None = None ,
401+ data : str | dict [str , Any ] | None = None ,
402+ params : dict [ str , Any ] | None = None ,
403+ headers : dict [ str , str ] | None = None ,
404+ timeout_seconds : int | None = None ,
397405 ) -> requests .Response :
406+ """Make PATCH request."""
398407 return self ._make_request (
399408 'PATCH' ,
400409 path ,
@@ -408,10 +417,11 @@ def patch(
408417 def head (
409418 self ,
410419 path : str ,
411- params : Optional [ Dict [ str , Any ]] = None ,
412- headers : Optional [ Dict [ str , str ]] = None ,
413- timeout_seconds : Optional [ int ] = None ,
420+ params : dict [ str , Any ] | None = None ,
421+ headers : dict [ str , str ] | None = None ,
422+ timeout_seconds : int | None = None ,
414423 ) -> requests .Response :
424+ """Make HEAD request."""
415425 return self ._make_request (
416426 'HEAD' ,
417427 path ,
@@ -423,10 +433,11 @@ def head(
423433 def options (
424434 self ,
425435 path : str ,
426- params : Optional [ Dict [ str , Any ]] = None ,
427- headers : Optional [ Dict [ str , str ]] = None ,
428- timeout_seconds : Optional [ int ] = None ,
436+ params : dict [ str , Any ] | None = None ,
437+ headers : dict [ str , str ] | None = None ,
438+ timeout_seconds : int | None = None ,
429439 ) -> requests .Response :
440+ """Make OPTIONS request."""
430441 return self ._make_request (
431442 'OPTIONS' ,
432443 path ,
@@ -436,8 +447,7 @@ def options(
436447 )
437448
438449 def health (self , healthcheck_path : str = '/health' ) -> requests .Response :
439- """
440- Check the health status of the API.
450+ """Check the health status of the API.
441451
442452 Returns:
443453 requests.Response: The response from the health check
@@ -448,31 +458,32 @@ def health(self, healthcheck_path: str = '/health') -> requests.Response:
448458 try :
449459 return self .get (healthcheck_path )
450460 except InferenceClientError as e :
451- raise InferenceClientError (f'Health check failed: { str ( e ) } ' )
461+ raise InferenceClientError (f'Health check failed: { e !s } ' ) from e
452462
453463
454464@dataclass_json (undefined = Undefined .EXCLUDE )
455465@dataclass
456466class AsyncInferenceExecution :
467+ """Async inference execution."""
468+
457469 _inference_client : 'InferenceClient'
458470 id : str
459471 _status : AsyncStatus
460472 INFERENCE_ID_HEADER = 'X-Inference-Id'
461473
def status(self) -> AsyncStatus:
    """Return the status value currently stored on this execution.

    Only the stored status value is returned; this method reads the
    ``_status`` attribute and does nothing else.

    Returns:
        AsyncStatus: The stored status.
    """
    stored_status = self._status
    return stored_status
470481
471- def status_json (self ) -> Dict [str , Any ]:
472- """Get the current status of the async inference execution. Return the status json
482+ def status_json (self ) -> dict [str , Any ]:
483+ """Get the current status of the async inference execution. Return the status json.
473484
474485 Returns:
475- Dict [str, Any]: The status response containing the execution status and other metadata
486+ dict [str, Any]: The status response containing the execution status and other metadata
476487 """
477488 url = (
478489 f'{ self ._inference_client .base_domain } /status/{ self ._inference_client .deployment_name } '
@@ -489,11 +500,11 @@ def status_json(self) -> Dict[str, Any]:
489500
490501 return response_json
491502
492- def result (self ) -> Dict [str , Any ]:
503+ def result (self ) -> dict [str , Any ]:
493504 """Get the results of the async inference execution.
494505
495506 Returns:
496- Dict [str, Any]: The results of the inference execution
507+ dict [str, Any]: The results of the inference execution
497508 """
498509 url = (
499510 f'{ self ._inference_client .base_domain } /result/{ self ._inference_client .deployment_name } '
0 commit comments