Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ DataCrunch's Public API documentation [is available here](https://api.datacrunch
datacrunch = DataCrunchClient(CLIENT_ID, CLIENT_SECRET)

# Get all SSH keys
ssh_keys = datacrunch.ssh_keys.get()
ssh_keys = list(map(lambda key: key.id, ssh_keys))
ssh_keys = [key.id for key in datacrunch.ssh_keys.get()]

# Create a new instance
instance = datacrunch.instances.create(instance_type='1V100.6V',
Expand Down
147 changes: 79 additions & 68 deletions datacrunch/InferenceClient/inference_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses_json import dataclass_json, Undefined # type: ignore
from enum import Enum
from typing import Any
from urllib.parse import urlparse

import requests
from dataclasses_json import Undefined, dataclass_json # type: ignore
from requests.structures import CaseInsensitiveDict
from typing import Optional, Dict, Any, Union, Generator
from urllib.parse import urlparse
from enum import Enum


class InferenceClientError(Exception):
Expand All @@ -14,6 +16,8 @@ class InferenceClientError(Exception):


class AsyncStatus(str, Enum):
"""Async status."""

Initialized = 'Initialized'
Queue = 'Queue'
Inference = 'Inference'
Expand All @@ -23,6 +27,8 @@ class AsyncStatus(str, Enum):
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class InferenceResponse:
"""Inference response."""

headers: CaseInsensitiveDict[str]
status_code: int
status_text: str
Expand Down Expand Up @@ -64,6 +70,7 @@ def _is_stream_response(self, headers: CaseInsensitiveDict[str]) -> bool:
)

def output(self, is_text: bool = False) -> Any:
"""Get response output as a string or object."""
try:
if is_text:
return self._original_response.text
Expand All @@ -73,8 +80,8 @@ def output(self, is_text: bool = False) -> Any:
if self._is_stream_response(self._original_response.headers):
raise InferenceClientError(
'Response might be a stream, use the stream method instead'
)
raise InferenceClientError(f'Failed to parse response as JSON: {str(e)}')
) from e
raise InferenceClientError(f'Failed to parse response as JSON: {e!s}') from e

def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any, None, None]:
"""Stream the response content.
Expand All @@ -97,11 +104,12 @@ def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any,


class InferenceClient:
"""Inference client."""

def __init__(
self, inference_key: str, endpoint_base_url: str, timeout_seconds: int = 60 * 5
) -> None:
"""
Initialize the InferenceClient.
"""Initialize the InferenceClient.

Args:
inference_key: The authentication key for the API
Expand Down Expand Up @@ -136,37 +144,33 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._session.close()

@property
def global_headers(self) -> Dict[str, str]:
"""
Get the current global headers that will be used for all requests.
def global_headers(self) -> dict[str, str]:
"""Get the current global headers that will be used for all requests.

Returns:
Dictionary of current global headers
"""
return self._global_headers.copy()

def set_global_header(self, key: str, value: str) -> None:
"""
Set or update a global header that will be used for all requests.
"""Set or update a global header that will be used for all requests.

Args:
key: Header name
value: Header value
"""
self._global_headers[key] = value

def set_global_headers(self, headers: Dict[str, str]) -> None:
"""
Set multiple global headers at once that will be used for all requests.
def set_global_headers(self, headers: dict[str, str]) -> None:
"""Set multiple global headers at once that will be used for all requests.

Args:
headers: Dictionary of headers to set globally
"""
self._global_headers.update(headers)

def remove_global_header(self, key: str) -> None:
"""
Remove a global header.
"""Remove a global header.

Args:
key: Header name to remove from global headers
Expand All @@ -179,10 +183,9 @@ def _build_url(self, path: str) -> str:
return f'{self.endpoint_base_url}/{path.lstrip("/")}'

def _build_request_headers(
self, request_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""
Build the final headers by merging global headers with request-specific headers.
self, request_headers: dict[str, str] | None = None
) -> dict[str, str]:
"""Build the final headers by merging global headers with request-specific headers.

Args:
request_headers: Optional headers specific to this request
Expand All @@ -196,8 +199,7 @@ def _build_request_headers(
return headers

def _make_request(self, method: str, path: str, **kwargs) -> requests.Response:
"""
Make an HTTP request with error handling.
"""Make an HTTP request with error handling.

Args:
method: HTTP method to use
Expand All @@ -221,17 +223,19 @@ def _make_request(self, method: str, path: str, **kwargs) -> requests.Response:
)
response.raise_for_status()
return response
except requests.exceptions.Timeout:
raise InferenceClientError(f'Request to {path} timed out after {timeout} seconds')
except requests.exceptions.Timeout as e:
raise InferenceClientError(
f'Request to {path} timed out after {timeout} seconds'
) from e
except requests.exceptions.RequestException as e:
raise InferenceClientError(f'Request to {path} failed: {str(e)}')
raise InferenceClientError(f'Request to {path} failed: {e!s}') from e

def run_sync(
self,
data: Dict[str, Any],
data: dict[str, Any],
path: str = '',
timeout_seconds: int = 60 * 5,
headers: Optional[Dict[str, str]] = None,
headers: dict[str, str] | None = None,
http_method: str = 'POST',
stream: bool = False,
):
Expand Down Expand Up @@ -269,10 +273,10 @@ def run_sync(

def run(
self,
data: Dict[str, Any],
data: dict[str, Any],
path: str = '',
timeout_seconds: int = 60 * 5,
headers: Optional[Dict[str, str]] = None,
headers: dict[str, str] | None = None,
http_method: str = 'POST',
no_response: bool = False,
):
Expand Down Expand Up @@ -325,23 +329,25 @@ def run(
def get(
self,
path: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout_seconds: Optional[int] = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
timeout_seconds: int | None = None,
) -> requests.Response:
"""Make GET request."""
return self._make_request(
'GET', path, params=params, headers=headers, timeout_seconds=timeout_seconds
)

def post(
self,
path: str,
json: Optional[Dict[str, Any]] = None,
data: Optional[Union[str, Dict[str, Any]]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout_seconds: Optional[int] = None,
json: dict[str, Any] | None = None,
data: str | dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
timeout_seconds: int | None = None,
) -> requests.Response:
"""Make POST request."""
return self._make_request(
'POST',
path,
Expand All @@ -355,12 +361,13 @@ def post(
def put(
self,
path: str,
json: Optional[Dict[str, Any]] = None,
data: Optional[Union[str, Dict[str, Any]]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout_seconds: Optional[int] = None,
json: dict[str, Any] | None = None,
data: str | dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
timeout_seconds: int | None = None,
) -> requests.Response:
"""Make PUT request."""
return self._make_request(
'PUT',
path,
Expand All @@ -374,10 +381,11 @@ def put(
def delete(
self,
path: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout_seconds: Optional[int] = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
timeout_seconds: int | None = None,
) -> requests.Response:
"""Make DELETE request."""
return self._make_request(
'DELETE',
path,
Expand All @@ -389,12 +397,13 @@ def delete(
def patch(
self,
path: str,
json: Optional[Dict[str, Any]] = None,
data: Optional[Union[str, Dict[str, Any]]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout_seconds: Optional[int] = None,
json: dict[str, Any] | None = None,
data: str | dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
timeout_seconds: int | None = None,
) -> requests.Response:
"""Make PATCH request."""
return self._make_request(
'PATCH',
path,
Expand All @@ -408,10 +417,11 @@ def patch(
def head(
self,
path: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout_seconds: Optional[int] = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
timeout_seconds: int | None = None,
) -> requests.Response:
"""Make HEAD request."""
return self._make_request(
'HEAD',
path,
Expand All @@ -423,10 +433,11 @@ def head(
def options(
self,
path: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout_seconds: Optional[int] = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
timeout_seconds: int | None = None,
) -> requests.Response:
"""Make OPTIONS request."""
return self._make_request(
'OPTIONS',
path,
Expand All @@ -436,8 +447,7 @@ def options(
)

def health(self, healthcheck_path: str = '/health') -> requests.Response:
"""
Check the health status of the API.
"""Check the health status of the API.

Returns:
requests.Response: The response from the health check
Expand All @@ -448,31 +458,32 @@ def health(self, healthcheck_path: str = '/health') -> requests.Response:
try:
return self.get(healthcheck_path)
except InferenceClientError as e:
raise InferenceClientError(f'Health check failed: {str(e)}')
raise InferenceClientError(f'Health check failed: {e!s}') from e


@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class AsyncInferenceExecution:
"""Async inference execution."""

_inference_client: 'InferenceClient'
id: str
_status: AsyncStatus
INFERENCE_ID_HEADER = 'X-Inference-Id'

def status(self) -> AsyncStatus:
"""Get the current stored status of the async inference execution. Only the status value type
"""Get the current stored status of the async inference execution. Only the status value type.

Returns:
AsyncStatus: The status object
"""

return self._status

def status_json(self) -> Dict[str, Any]:
"""Get the current status of the async inference execution. Return the status json
def status_json(self) -> dict[str, Any]:
"""Get the current status of the async inference execution. Return the status json.

Returns:
Dict[str, Any]: The status response containing the execution status and other metadata
dict[str, Any]: The status response containing the execution status and other metadata
"""
url = (
f'{self._inference_client.base_domain}/status/{self._inference_client.deployment_name}'
Expand All @@ -489,11 +500,11 @@ def status_json(self) -> Dict[str, Any]:

return response_json

def result(self) -> Dict[str, Any]:
def result(self) -> dict[str, Any]:
"""Get the results of the async inference execution.

Returns:
Dict[str, Any]: The results of the inference execution
dict[str, Any]: The results of the inference execution
"""
url = (
f'{self._inference_client.base_domain}/result/{self._inference_client.deployment_name}'
Expand Down
3 changes: 1 addition & 2 deletions datacrunch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from datacrunch.datacrunch import DataCrunchClient

from datacrunch._version import __version__
from datacrunch.datacrunch import DataCrunchClient
7 changes: 4 additions & 3 deletions datacrunch/authentication/authentication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import requests
import time

import requests

from datacrunch.http_client.http_client import handle_error

TOKEN_ENDPOINT = '/oauth2/token'
Expand All @@ -10,15 +11,15 @@


class AuthenticationService:
"""A service for client authentication"""
"""A service for client authentication."""

def __init__(self, client_id: str, client_secret: str, base_url: str) -> None:
self._base_url = base_url
self._client_id = client_id
self._client_secret = client_secret

def authenticate(self) -> dict:
"""Authenticate the client and store the access & refresh tokens
"""Authenticate the client and store the access & refresh tokens.

returns an authentication data dictionary with the following schema:
{
Expand Down
Loading