diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 177167b..42e41b5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,12 @@ repos: - - repo: https://github.com/ambv/black - rev: 22.3.0 - hooks: - - id: black - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.259 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.2 hooks: - id: ruff - args: [--fix, --exit-non-zero-on-fix] + name: sort imports with ruff + args: [--select, I, --fix] + - id: ruff-format + name: format with ruff - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.981 hooks: diff --git a/pyproject.toml b/pyproject.toml index 5d3446b..87287c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "dacite", "httpx", "tenacity", + "inflect", ] [project.optional-dependencies] @@ -24,6 +25,7 @@ tests = [ "pytest-cov", "ruff", "types-aiofiles", + "pre-commit", ] [tools.setuptools] diff --git a/rossum_api/domain_logic/annotations.py b/rossum_api/domain_logic/annotations.py new file mode 100644 index 0000000..3530f04 --- /dev/null +++ b/rossum_api/domain_logic/annotations.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from rossum_api.models import Annotation + +if TYPE_CHECKING: + from typing import Sequence + + +def validate_list_annotations_params( + sideloads: Sequence[str] = (), + content_schema_ids: Sequence[str] = (), +) -> None: + """Validate parameters to list_annotations request.""" + if sideloads and "content" in sideloads and not content_schema_ids: + raise ValueError( + 'When content sideloading is requested, "content_schema_ids" must be provided' + ) + + +def get_http_method_for_annotation_export(**filters) -> str: + """to_status filter requires a different HTTP method. + + https://elis.rossum.ai/api/docs/#export-annotations + """ + if "to_status" in filters: + return "POST" + return "GET" + + +def is_annotation_imported(annotation: Annotation) -> bool: + return annotation.status not in ("importing", "created") diff --git a/rossum_api/domain_logic/documents.py b/rossum_api/domain_logic/documents.py new file mode 100644 index 0000000..8cfef5e --- /dev/null +++ b/rossum_api/domain_logic/documents.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import json +from typing import Any, Optional + +import httpx + + +def build_create_document_params( + file_name: str, + file_data: bytes, + metadata: Optional[dict[str, Any]], + parent: Optional[str], +) -> dict[str, Any]: + metadata = metadata or {} + files: httpx._types.RequestFiles = { + "content": (file_name, file_data), + "metadata": ("", json.dumps(metadata).encode("utf-8")), + } + if parent: + files["parent"] = ("", parent) + return files diff --git a/rossum_api/domain_logic/pagination.py b/rossum_api/domain_logic/pagination.py new file mode 100644 index 0000000..82c1293 --- /dev/null +++ b/rossum_api/domain_logic/pagination.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Sequence + +DEFAULT_PAGE_SIZE = 100 + + +def build_pagination_params(ordering: Sequence[str], page_size: int = DEFAULT_PAGE_SIZE) -> dict: + """Build params used for fetching paginated resources.""" + return { + "page_size": page_size, + "ordering": ",".join(ordering), + } diff --git a/rossum_api/domain_logic/retry.py b/rossum_api/domain_logic/retry.py new file mode 100644 index 0000000..2c26ff3 --- /dev/null +++ b/rossum_api/domain_logic/retry.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import httpx + +RETRIED_HTTP_CODES = (408, 429, 500, 502, 503, 504) + + +class AlwaysRetry(Exception): + pass + + +def should_retry(exc: BaseException) -> bool: + if isinstance(exc, (AlwaysRetry, httpx.RequestError)): + return True + if isinstance(exc, httpx.HTTPStatusError): + return exc.response.status_code in RETRIED_HTTP_CODES + return False diff --git a/rossum_api/domain_logic/search.py b/rossum_api/domain_logic/search.py new file mode 100644 index 0000000..4a63878 --- /dev/null +++ b/rossum_api/domain_logic/search.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Any, Optional + + +def validate_search_params( + query: Optional[dict] = None, + query_string: Optional[dict] = None, +): + if not query and not query_string: + raise ValueError("Either query or query_string must be provided") + + +def build_search_params( + query: Optional[dict] = None, + query_string: Optional[dict] = None, +) -> dict[str, Any]: + json_payload = {} + if query: + json_payload["query"] = query + if query_string: + json_payload["query_string"] = query_string + return json_payload diff --git a/rossum_api/domain_logic/sideloads.py b/rossum_api/domain_logic/sideloads.py new file mode 100644 index 0000000..82c3ce5 --- /dev/null +++ b/rossum_api/domain_logic/sideloads.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING + +from rossum_api.domain_logic.urls import ( + parse_annotation_id_from_datapoint_url, + parse_resource_id_from_url, +) +from rossum_api.utils import to_singular + +if TYPE_CHECKING: + from typing import Any, Sequence, Union + + +def _group_sideloads_by_annotation_id( + sideloads: Sequence[str], response_data: dict[str, Any] +) -> dict[str, dict[int, Union[dict, list]]]: + sideloads_by_id: dict[str, dict[int, Union[dict, list]]] = {} + for sideload in sideloads: + if sideload == "content": + # Datapoints from all annotations are present in response data, we have to construct + # content (list of datapoints) for each annotation. + def get_annotation_id(datapoint: dict[str, Any]) -> int: + return parse_annotation_id_from_datapoint_url(datapoint["url"]) + + sideloads_by_id[sideload] = { + k: list(v) + for k, v in itertools.groupby( + sorted(response_data[sideload], key=get_annotation_id), + key=get_annotation_id, + ) + } + else: + sideloads_by_id[sideload] = {s["id"]: s for s in response_data[sideload]} + return sideloads_by_id + + +def embed_sideloads(response_data, sideloads: Sequence[str]) -> None: + """Put sideloads into the response data.""" + sideloads_by_id = _group_sideloads_by_annotation_id(sideloads, response_data) + for result, sideload in itertools.product(response_data["results"], sideloads): + sideload_name = to_singular(sideload) + url = result[sideload_name] + if url is None: + continue + sideload_id = parse_resource_id_from_url(url) + + result[sideload_name] = sideloads_by_id[sideload].get( + sideload_id, [] + ) # `content` can have 0 datapoints, use [] default value in this case + + +def build_sideload_params(sideloads: Sequence[str], content_schema_ids: Sequence[str]) -> dict: + """Build params used for sideloading.""" + return { + "sideload": ",".join(sideloads), + "content.schema_id": ",".join(content_schema_ids), + } diff --git a/rossum_api/domain_logic/upload.py b/rossum_api/domain_logic/upload.py new file mode 100644 index 0000000..35316e3 --- /dev/null +++ b/rossum_api/domain_logic/upload.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, BinaryIO, Optional + + +def build_upload_files( + fp: BinaryIO, + filename: str, + values: Optional[dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + """Build request files for the upload endpoint.""" + files = {"content": (filename, fp.read(), "application/octet-stream")} + + # Filename of values and metadata must be "", otherwise Elis API returns HTTP 400 with body + # "Value must be valid JSON." + if values is not None: + files["values"] = ("", json.dumps(values).encode("utf-8"), "application/json") + if metadata is not None: + files["metadata"] = ("", json.dumps(metadata).encode("utf-8"), "application/json") + + return files diff --git a/rossum_api/domain_logic/urls.py b/rossum_api/domain_logic/urls.py index 11dbdbb..f574e13 100644 --- a/rossum_api/domain_logic/urls.py +++ b/rossum_api/domain_logic/urls.py @@ -3,6 +3,8 @@ import re from typing import TYPE_CHECKING +from rossum_api.api_client import Resource + if TYPE_CHECKING: from rossum_api.models import Resource diff --git a/rossum_api/dtos.py b/rossum_api/dtos.py new file mode 100644 index 0000000..5acc6a6 --- /dev/null +++ b/rossum_api/dtos.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import dataclasses + + +@dataclasses.dataclass +class Token: + token: str + + +@dataclasses.dataclass +class UserCredentials: + username: str + password: str diff --git a/rossum_api/elis_api_client.py b/rossum_api/elis_api_client.py index 852f469..ac97913 100644 --- a/rossum_api/elis_api_client.py +++ b/rossum_api/elis_api_client.py @@ -573,7 +573,7 @@ async def list_all_email_templates( self, ordering: Sequence[str] = (), **filters: Any, - ) -> AsyncIterator[Connector]: + ) -> AsyncIterator[EmailTemplate]: """https://elis.rossum.ai/api/docs/#list-all-email-templates.""" async for c in self._http_client.fetch_all(Resource.EmailTemplate, ordering, **filters): yield self._deserializer(Resource.EmailTemplate, c) diff --git a/rossum_api/elis_api_client_sync.py b/rossum_api/elis_api_client_sync.py index 0bdcda4..c2a2960 100644 --- a/rossum_api/elis_api_client_sync.py +++ b/rossum_api/elis_api_client_sync.py @@ -1,176 +1,144 @@ from __future__ import annotations -import asyncio -import typing -from concurrent.futures import ThreadPoolExecutor -from queue import Queue as ThreadSafeQueue - -from rossum_api import ElisAPIClient -from rossum_api.domain_logic.urls import DEFAULT_BASE_URL - -if typing.TYPE_CHECKING: - import pathlib - from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Iterator, - List, - Optional, - Sequence, - Tuple, - TypeVar, - Union, - ) - - import httpx - - from rossum_api import ExportFileFormats - from rossum_api.api_client import APIClient - from rossum_api.models import Deserializer - from rossum_api.models.annotation import Annotation - from rossum_api.models.connector import Connector - from rossum_api.models.document import Document - from rossum_api.models.email_template import EmailTemplate - from rossum_api.models.engine import Engine - from rossum_api.models.group import Group - from rossum_api.models.hook import Hook - from rossum_api.models.inbox import Inbox - from rossum_api.models.organization import Organization - from rossum_api.models.queue import Queue - from rossum_api.models.schema import Schema - from rossum_api.models.task import Task - from rossum_api.models.upload import Upload - from rossum_api.models.user import User - from rossum_api.models.workspace import Workspace - - T = TypeVar("T") - - -class Sideload: - pass - - -class AsyncRuntimeError(Exception): - pass +import pathlib +import time +from pathlib import Path +from typing import Any, Callable, Iterator, Optional, Sequence, Tuple, Union, cast + +from rossum_api import ExportFileFormats +from rossum_api.api_client import Resource +from rossum_api.domain_logic.annotations import ( + get_http_method_for_annotation_export, + is_annotation_imported, + validate_list_annotations_params, +) +from rossum_api.domain_logic.documents import build_create_document_params +from rossum_api.domain_logic.search import build_search_params, validate_search_params +from rossum_api.domain_logic.upload import build_upload_files +from rossum_api.domain_logic.urls import get_upload_url, parse_resource_id_from_url +from rossum_api.dtos import Token, UserCredentials +from rossum_api.internal_sync_client import InternalSyncRossumAPIClient +from rossum_api.models import ( + Annotation, + Connector, + Deserializer, + Document, + EmailTemplate, + Engine, + Group, + Hook, + Inbox, + Organization, + Queue, + Schema, + Task, + Upload, + User, + Workspace, + deserialize_default, +) +from rossum_api.models.task import TaskStatus class ElisAPIClientSync: def __init__( self, - username: Optional[str] = None, - password: Optional[str] = None, - token: Optional[str] = None, - base_url: str = DEFAULT_BASE_URL, - http_client: Optional[APIClient] = None, + base_url: str, + credentials: UserCredentials | Token, deserializer: Optional[Deserializer] = None, ): - """ - Parameters - ---------- - base_url - base API URL including the "/api" and version ("/v1") in the url path. For example - "https://elis.rossum.ai/api/v1" - deserializer - pass a custom deserialization callable if different model classes should be returned - """ - self.elis_api_client = ElisAPIClient( - username, password, token, base_url, http_client, deserializer - ) - # The executor is never terminated. We would either need to turn the client into a context manager which is inconvenient for users or terminate it after each request which is wasteful. Keeping one thread around seems like the lesser evil. - self.executor = ThreadPoolExecutor(max_workers=1) - - def _iter_over_async(self, ait: AsyncIterator[T]) -> Iterator[T]: - """Iterate over an async generator from sync code without materializing all items into memory.""" - queue: ThreadSafeQueue = ( - ThreadSafeQueue() - ) # To communicate with the thread executing the async generator - - async def async_iter_to_list(ait: AsyncIterator[T], queue: ThreadSafeQueue): - try: - async for obj in ait: - queue.put(obj) - finally: - queue.put(None) # Signal iterator was consumed - - future = self.executor.submit(asyncio.run, async_iter_to_list(ait, queue)) # type: ignore - - # Consume the queue until completion to retain the iterator nature even in sync context - while True: - item = queue.get() - if item is None: # None is used to signal completion - break - yield item - - future.result() - - def _run_coroutine(self, coroutine): - future = self.executor.submit(asyncio.run, coroutine) - return future.result() # Wait for the coroutine to complete - - # ##### QUEUE ##### - def retrieve_queue(self, queue_id: int) -> Queue: + self._deserializer = deserializer or deserialize_default + self.internal_client = InternalSyncRossumAPIClient(base_url, credentials) + + # ##### QUEUES ##### + + def retrieve_queue( + self, + queue_id: int, + ) -> Queue: """https://elis.rossum.ai/api/docs/#retrieve-a-queue-2.""" - return self._run_coroutine(self.elis_api_client.retrieve_queue(queue_id)) + queue = self.internal_client.fetch_resource(Resource.Queue, queue_id) + return self._deserializer(Resource.Queue, queue) - def list_all_queues( + def list_queues( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[Queue]: """https://elis.rossum.ai/api/docs/#list-all-queues.""" - return self._iter_over_async(self.elis_api_client.list_all_queues(ordering, **filters)) + for q in self.internal_client.fetch_resources(Resource.Queue, ordering, **filters): + yield self._deserializer(Resource.Queue, q) - def create_new_queue( - self, - data: Dict[str, Any], - ) -> Queue: + def create_new_queue(self, data: dict[str, Any]) -> Queue: """https://elis.rossum.ai/api/docs/#create-new-queue.""" - return self._run_coroutine(self.elis_api_client.create_new_queue(data)) + queue = self.internal_client.create(Resource.Queue, data) + return self._deserializer(Resource.Queue, queue) def delete_queue(self, queue_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-queue.""" - return self._run_coroutine(self.elis_api_client.delete_queue(queue_id)) + return self.internal_client.delete(Resource.Queue, queue_id) + + def _import_document( + self, + url: str, + files: Sequence[Tuple[Union[str, Path], str]], + values: Optional[dict[str, Any]], + metadata: Optional[dict[str, Any]], + ) -> list[int]: + """Depending on the endpoint, it either returns annotation IDs, or task IDs.""" + results = [] + for file_path, filename in files: + with open(file_path, "rb") as fp: + request_files = build_upload_files(fp, filename, values, metadata) + response_data = self.internal_client.upload(url, request_files) + (result,) = response_data[ + "results" + ] # We're uploading 1 file in 1 request, we can unpack + results.append(parse_resource_id_from_url(result["annotation"])) + return results def import_document( self, queue_id: int, - files: Sequence[Tuple[Union[str, pathlib.Path], str]], - values: Optional[Dict[str, Any]] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> List[int]: + files: Sequence[Tuple[Union[str, Path], str]], + values: Optional[dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> list[int]: """https://elis.rossum.ai/api/docs/#import-a-document. Deprecated now, consider upload_document. Parameters --------- + queue_id + ID of the queue to upload the files to files 2-tuple containing current filepath and name to be used by Elis for the uploaded file - metadata - metadata will be set to newly created annotation object values may be used to initialize datapoint values by setting the value of rir_field_names in the schema + metadata + will be set to newly created annotation object Returns ------- annotation_ids list of IDs of created annotations, respects the order of `files` argument """ - return self._run_coroutine( - self.elis_api_client.import_document(queue_id, files, values, metadata) - ) + url = get_upload_url(Resource.Queue, queue_id) + return self._import_document(url, files, values, metadata) # ##### UPLOAD ##### + def upload_document( self, queue_id: int, files: Sequence[Tuple[Union[str, pathlib.Path], str]], - values: Optional[Dict[str, Any]] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> List[Task]: - """https://elis.rossum.ai/api/docs/#create-upload. + values: Optional[dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> list[Task]: + """https://elis.rossum.ai/api/docs/#create-upload + + Does the same thing as import_document method, but uses a different endpoint. Parameters --------- @@ -178,10 +146,10 @@ def upload_document( ID of the queue to upload the files to files 2-tuple containing current filepath and name to be used by Elis for the uploaded file - metadata - metadata will be set to newly created annotation object values may be used to initialize datapoint values by setting the value of rir_field_names in the schema + metadata + will be set to newly created annotation object Returns ------- @@ -190,24 +158,33 @@ def upload_document( Tasks can be polled using poll_task and if succeeded, will contain a link to an Upload object that contains info on uploaded documents/annotations """ - return self._run_coroutine( - self.elis_api_client.upload_document(queue_id, files, values, metadata) - ) + url = f"uploads?queue={queue_id}" + task_ids = self._import_document(url, files, values, metadata) + return [self.retrieve_task(task_id) for task_id in task_ids] def retrieve_upload( self, upload_id: int, ) -> Upload: """Implements https://elis.rossum.ai/api/docs/#retrieve-upload.""" + upload = self.internal_client.fetch_resource(Resource.Upload, upload_id) + return self._deserializer(Resource.Upload, upload) - return self._run_coroutine(self.elis_api_client.retrieve_upload(upload_id)) + # ##### EXPORT ##### - def export_annotations_to_json(self, queue_id: int) -> Iterator[Annotation]: + def export_annotations_to_json( + self, + queue_id: int, + ) -> Iterator[Annotation]: """https://elis.rossum.ai/api/docs/#export-annotations. JSON export is paginated and returns the result in a way similar to other list_all methods. """ - return self._iter_over_async(self.elis_api_client.export_annotations_to_json(queue_id)) + for chunk in self.internal_client.export( + Resource.Queue, queue_id, "json", get_http_method_for_annotation_export() + ): + # JSON export can be translated directly to Annotation object + yield self._deserializer(Resource.Annotation, cast(dict, chunk)) def export_annotations_to_file( self, queue_id: int, export_format: ExportFileFormats @@ -216,85 +193,97 @@ def export_annotations_to_file( XLSX/CSV/XML exports can be huge, therefore byte streaming is used to keep memory consumption low. """ - return self._iter_over_async( - self.elis_api_client.export_annotations_to_file(queue_id, export_format) - ) + for chunk in self.internal_client.export( + Resource.Queue, queue_id, str(export_format), get_http_method_for_annotation_export() + ): + yield cast(bytes, chunk) # ##### ORGANIZATIONS ##### - def list_all_organizations( + + def list_organizations( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[Organization]: """https://elis.rossum.ai/api/docs/#list-all-organizations.""" - return self._iter_over_async( - self.elis_api_client.list_all_organizations(ordering, **filters) - ) + for o in self.internal_client.fetch_resources(Resource.Organization, ordering, **filters): + yield self._deserializer(Resource.Organization, o) - def retrieve_organization( - self, - org_id: int, - ) -> Organization: + def retrieve_organization(self, org_id: int) -> Organization: """https://elis.rossum.ai/api/docs/#retrieve-an-organization.""" - return self._run_coroutine(self.elis_api_client.retrieve_organization(org_id)) + organization = self.internal_client.fetch_resource(Resource.Organization, org_id) + return self._deserializer(Resource.Organization, organization) - def retrieve_own_organization(self) -> Organization: + def retrieve_my_organization(self) -> Organization: """Retrieve organization of currently logged in user.""" - return self._run_coroutine(self.elis_api_client.retrieve_own_organization()) + user: dict[Any, Any] = self.internal_client.fetch_resource(Resource.Auth, "user") + organization_id = parse_resource_id_from_url(user["organization"]) + return self.retrieve_organization(organization_id) # ##### SCHEMAS ##### - def list_all_schemas( + + def list_schemas( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[Schema]: """https://elis.rossum.ai/api/docs/#list-all-schemas.""" - return self._iter_over_async(self.elis_api_client.list_all_schemas(ordering, **filters)) + for s in self.internal_client.fetch_resources(Resource.Schema, ordering, **filters): + yield self._deserializer(Resource.Schema, s) def retrieve_schema(self, schema_id: int) -> Schema: """https://elis.rossum.ai/api/docs/#retrieve-a-schema.""" - return self._run_coroutine(self.elis_api_client.retrieve_schema(schema_id)) + schema: dict[Any, Any] = self.internal_client.fetch_resource(Resource.Schema, schema_id) + return self._deserializer(Resource.Schema, schema) - def create_new_schema(self, data: Dict[str, Any]) -> Schema: + def create_new_schema(self, data: dict[str, Any]) -> Schema: """https://elis.rossum.ai/api/docs/#create-a-new-schema.""" - return self._run_coroutine(self.elis_api_client.create_new_schema(data)) + schema = self.internal_client.create(Resource.Schema, data) + return self._deserializer(Resource.Schema, schema) def delete_schema(self, schema_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-schema.""" - return self._run_coroutine(self.elis_api_client.delete_schema(schema_id)) - - # ##### ENGINES ##### - def retrieve_engine(self, engine_id: int) -> Engine: - """https://elis.rossum.ai/api/docs/#retrieve-a-schema.""" - return self._run_coroutine(self.elis_api_client.retrieve_engine(engine_id)) + return self.internal_client.delete(Resource.Schema, schema_id) # ##### USERS ##### - def list_all_users( + + def list_users( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[User]: """https://elis.rossum.ai/api/docs/#list-all-users.""" - return self._iter_over_async(self.elis_api_client.list_all_users(ordering, **filters)) + for u in self.internal_client.fetch_resources(Resource.User, ordering, **filters): + yield self._deserializer(Resource.User, u) def retrieve_user(self, user_id: int) -> User: """https://elis.rossum.ai/api/docs/#retrieve-a-user-2.""" - return self._run_coroutine(self.elis_api_client.retrieve_user(user_id)) - - def create_new_user(self, data: Dict[str, Any]) -> User: - """https://elis.rossum.ai/api/docs/#create-new-user.""" - return self._run_coroutine(self.elis_api_client.create_new_user(data)) + user = self.internal_client.fetch_resource(Resource.User, user_id) + return self._deserializer(Resource.User, user) - # TODO: specific method in APICLient + # TODO: specific method in InternalSyncRossumAPIClient def change_user_password(self, new_password: str) -> dict: - return {} + raise NotImplementedError - # TODO: specific method in APICLient + # TODO: specific method in InternalSyncRossumAPIClient def reset_user_password(self, email: str) -> dict: - return {} + raise NotImplementedError + + def create_new_user(self, data: dict[str, Any]) -> User: + """https://elis.rossum.ai/api/docs/#create-new-user.""" + user = self.internal_client.create(Resource.User, data) + return self._deserializer(Resource.User, user) # ##### ANNOTATIONS ##### - def list_all_annotations( + + def retrieve_annotation(self, annotation_id: int, sideloads: Sequence[str] = ()) -> Annotation: + """https://elis.rossum.ai/api/docs/#retrieve-an-annotation.""" + annotation = self.internal_client.fetch_resource(Resource.Annotation, annotation_id) + if sideloads: + self.internal_client.sideload(annotation, sideloads) + return self._deserializer(Resource.Annotation, annotation) + + def list_annotations( self, ordering: Sequence[str] = (), sideloads: Sequence[str] = (), @@ -302,11 +291,12 @@ def list_all_annotations( **filters: Any, ) -> Iterator[Annotation]: """https://elis.rossum.ai/api/docs/#list-all-annotations.""" - return self._iter_over_async( - self.elis_api_client.list_all_annotations( - ordering, sideloads, content_schema_ids, **filters - ) - ) + validate_list_annotations_params(sideloads, content_schema_ids) + + for annotation in self.internal_client.fetch_resources( + Resource.Annotation, ordering, sideloads, content_schema_ids, **filters + ): + yield self._deserializer(Resource.Annotation, annotation) def search_for_annotations( self, @@ -316,18 +306,18 @@ def search_for_annotations( sideloads: Sequence[str] = (), **kwargs: Any, ) -> Iterator[Annotation]: - """https://elis.rossum.ai/api/docs/internal/#search-for-annotations.""" - return self._iter_over_async( - self.elis_api_client.search_for_annotations( - query, query_string, ordering, sideloads, **kwargs - ) - ) - - def retrieve_annotation(self, annotation_id: int, sideloads: Sequence[str] = ()) -> Annotation: - """https://elis.rossum.ai/api/docs/#retrieve-an-annotation.""" - return self._run_coroutine( - self.elis_api_client.retrieve_annotation(annotation_id, sideloads) - ) + """https://elis.rossum.ai/api/docs/#search-for-annotations.""" + validate_search_params(query, query_string) + json_payload = build_search_params(query, query_string) + for annotation in self.internal_client.fetch_resources_by_url( + f"{Resource.Annotation.value}/search", + ordering, + sideloads, + json=json_payload, + method="POST", + **kwargs, + ): + yield self._deserializer(Resource.Annotation, annotation) def poll_annotation( self, @@ -340,9 +330,33 @@ def poll_annotation( Sideloading is done only once after the predicate becomes true to avoid spamming the server. """ - return self._run_coroutine( - self.elis_api_client.poll_annotation(annotation_id, predicate, sleep_s, sideloads) + resource = Resource.Annotation + + annotation_response = self.internal_client.fetch_resource(resource, annotation_id) + # Deserialize early, we want the predicate to work with Annotation instances for convenience. + annotation = self._deserializer(resource, annotation_response) + + while not predicate(annotation): + time.sleep(sleep_s) + annotation_response = self.internal_client.fetch_resource(resource, annotation_id) + annotation = self._deserializer(resource, annotation_response) + + if sideloads: + self.internal_client.sideload(annotation_response, sideloads) + return self._deserializer(resource, annotation_response) + + def poll_annotation_until_imported(self, annotation_id: int, **poll_kwargs: Any) -> Annotation: + """A shortcut for waiting until annotation is imported.""" + return self.poll_annotation(annotation_id, is_annotation_imported, **poll_kwargs) + + # ##### TASKS ##### + + def retrieve_task(self, task_id: int) -> Task: + """https://elis.rossum.ai/api/docs/#retrieve-task.""" + task = self.internal_client.fetch_resource( + Resource.Task, task_id, request_params={"no_redirect": "True"} ) + return self._deserializer(Resource.Task, task) def poll_task( self, @@ -350,8 +364,16 @@ def poll_task( predicate: Callable[[Task], bool], sleep_s: int = 3, ) -> Task: - """Poll on Task until predicate is true.""" - return self._run_coroutine(self.elis_api_client.poll_task(task_id, predicate, sleep_s)) + """Poll on Task until predicate is true. + + As with Annotation polling, there is no innate retry limit.""" + task = self.retrieve_task(task_id) + + while not predicate(task): + time.sleep(sleep_s) + task = self.retrieve_task(task_id) + + return task def poll_task_until_succeeded( self, @@ -359,212 +381,221 @@ def poll_task_until_succeeded( sleep_s: int = 3, ) -> Task: """Poll on Task until it is succeeded.""" - return self._run_coroutine( - self.elis_api_client.poll_task_until_succeeded(task_id, sleep_s) - ) - - def retrieve_task(self, task_id: int) -> Task: - """https://elis.rossum.ai/api/docs/#retrieve-task.""" - return self._run_coroutine(self.elis_api_client.retrieve_task(task_id)) - - def poll_annotation_until_imported(self, annotation_id: int, **poll_kwargs: Any) -> Annotation: - """A shortcut for waiting until annotation is imported.""" - return self._run_coroutine( - self.elis_api_client.poll_annotation_until_imported(annotation_id, **poll_kwargs) - ) + return self.poll_task(task_id, lambda a: a.status == TaskStatus.SUCCEEDED, sleep_s) def upload_and_wait_until_imported( self, queue_id: int, filepath: Union[str, pathlib.Path], filename: str, **poll_kwargs ) -> Annotation: """A shortcut for uploading a single file and waiting until its annotation is imported.""" - return self._run_coroutine( - self.elis_api_client.upload_and_wait_until_imported( - queue_id, filepath, filename, **poll_kwargs - ) - ) + (annotation_id,) = self.import_document(queue_id, [(filepath, filename)]) + return self.poll_annotation_until_imported(annotation_id, **poll_kwargs) def start_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#start-annotation""" - self._run_coroutine(self.elis_api_client.start_annotation(annotation_id)) + self.internal_client.request_json( + "POST", f"{Resource.Annotation.value}/{annotation_id}/start" + ) - def update_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation: + def update_annotation(self, annotation_id: int, data: dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#update-an-annotation.""" - return self._run_coroutine(self.elis_api_client.update_annotation(annotation_id, data)) + annotation = self.internal_client.replace(Resource.Annotation, annotation_id, data) + return self._deserializer(Resource.Annotation, annotation) - def update_part_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation: + + def update_part_annotation(self, annotation_id: int, data: dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#update-part-of-an-annotation.""" - return self._run_coroutine( - self.elis_api_client.update_part_annotation(annotation_id, data) - ) + annotation = self.internal_client.update(Resource.Annotation, annotation_id, data) + return self._deserializer(Resource.Annotation, annotation) + def bulk_update_annotation_data( - self, annotation_id: int, operations: List[Dict[str, Any]] + self, annotation_id: int, operations: list[dict[str, Any]] ) -> None: """https://elis.rossum.ai/api/docs/#bulk-update-annotation-data""" - self._run_coroutine( - self.elis_api_client.bulk_update_annotation_data(annotation_id, operations) + + self.internal_client.request_json( + "POST", + f"{Resource.Annotation.value}/{annotation_id}/content/operations", + json={"operations": operations}, ) def confirm_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#confirm-annotation""" - self._run_coroutine(self.elis_api_client.confirm_annotation(annotation_id)) + self.internal_client.request_json( + "POST", f"{Resource.Annotation.value}/{annotation_id}/confirm" + ) def create_new_annotation(self, data: dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#create-an-annotation""" - return self._run_coroutine(self.elis_api_client.create_new_annotation(data)) + annotation = self.internal_client.create(Resource.Annotation, data) + return self._deserializer(Resource.Annotation, annotation) def delete_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#switch-to-deleted""" - return self._run_coroutine(self.elis_api_client.delete_annotation(annotation_id)) + self.internal_client.request( + "POST", url=f"{Resource.Annotation.value}/{annotation_id}/delete" + ) def cancel_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#cancel-annotation""" - return self._run_coroutine(self.elis_api_client.cancel_annotation(annotation_id)) + self.internal_client.request( + "POST", url=f"{Resource.Annotation.value}/{annotation_id}/cancel" + ) # ##### DOCUMENTS ##### + def retrieve_document(self, document_id: int) -> Document: """https://elis.rossum.ai/api/docs/#retrieve-a-document""" - return self._run_coroutine(self.elis_api_client.retrieve_document(document_id)) + document: dict[Any, Any] = self.internal_client.fetch_resource( + Resource.Document, document_id + ) + return self._deserializer(Resource.Document, document) def retrieve_document_content(self, document_id: int) -> bytes: """https://elis.rossum.ai/api/docs/#document-content""" - return self._run_coroutine(self.elis_api_client.retrieve_document_content(document_id)) + document_content = self.internal_client.request( + "GET", url=f"{Resource.Document.value}/{document_id}/content" + ) + return document_content.content + def create_new_document( self, file_name: str, file_data: bytes, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, parent: Optional[str] = None, ) -> Document: """https://elis.rossum.ai/api/docs/#create-document""" - return self._run_coroutine( - self.elis_api_client.create_new_document(file_name, file_data, metadata, parent) + files = build_create_document_params(file_name, file_data, metadata, parent) + document = self.internal_client.request_json( + "POST", url=Resource.Document.value, files=files + ) + return self._deserializer(Resource.Document, document) # ##### WORKSPACES ##### - def list_all_workspaces( + + def list_workspaces( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[Workspace]: """https://elis.rossum.ai/api/docs/#list-all-workspaces.""" - return self._iter_over_async(self.elis_api_client.list_all_workspaces(ordering, **filters)) + for w in self.internal_client.fetch_resources(Resource.Workspace, ordering, **filters): + yield self._deserializer(Resource.Workspace, w) def retrieve_workspace(self, workspace_id: int) -> Workspace: """https://elis.rossum.ai/api/docs/#retrieve-a-workspace.""" - return self._run_coroutine(self.elis_api_client.retrieve_workspace(workspace_id)) + workspace = self.internal_client.fetch_resource(Resource.Workspace, workspace_id) + + return self._deserializer(Resource.Workspace, workspace) - def create_new_workspace(self, data: Dict[str, Any]) -> Workspace: + def create_new_workspace(self, data: dict[str, Any]) -> Workspace: """https://elis.rossum.ai/api/docs/#create-a-new-workspace.""" - return self._run_coroutine(self.elis_api_client.create_new_workspace(data)) + workspace = self.internal_client.create(Resource.Workspace, data) + + return self._deserializer(Resource.Workspace, workspace) def delete_workspace(self, workspace_id: int) -> None: - """https://elis.rossum.ai/api/docs/#retrieve-a-workspace.""" - return self._run_coroutine(self.elis_api_client.delete_workspace(workspace_id)) + """https://elis.rossum.ai/api/docs/#delete-a-workspace.""" + return self.internal_client.delete(Resource.Workspace, workspace_id) + + # ##### ENGINE ##### + + def retrieve_engine(self, engine_id: int) -> Engine: + """https://elis.rossum.ai/api/docs/#retrieve-an-engine.""" + engine = self.internal_client.fetch_resource(Resource.Engine, engine_id) + return self._deserializer(Resource.Engine, engine) # ##### INBOX ##### - def create_new_inbox( - self, - data: Dict[str, Any], - ) -> Inbox: + + def create_new_inbox(self, data: dict[str, Any]) -> Inbox: """https://elis.rossum.ai/api/docs/#create-a-new-inbox.""" - return self._run_coroutine(self.elis_api_client.create_new_inbox(data)) + inbox = self.internal_client.create(Resource.Inbox, data) + return self._deserializer(Resource.Inbox, inbox) # ##### EMAIL TEMPLATES ##### - def list_all_email_templates( + + def list_email_templates( self, ordering: Sequence[str] = (), **filters: Any, - ) -> Iterator[Connector]: + ) -> Iterator[EmailTemplate]: """https://elis.rossum.ai/api/docs/#list-all-email-templates.""" - return self._iter_over_async( - self.elis_api_client.list_all_email_templates(ordering, **filters) - ) + for c in self.internal_client.fetch_resources(Resource.EmailTemplate, ordering, **filters): + yield self._deserializer(Resource.EmailTemplate, c) def retrieve_email_template(self, email_template_id: int) -> EmailTemplate: """https://elis.rossum.ai/api/docs/#retrieve-an-email-template-object.""" - return self._run_coroutine(self.elis_api_client.retrieve_email_template(email_template_id)) + email_template = self.internal_client.fetch_resource( + Resource.EmailTemplate, email_template_id + ) + return self._deserializer(Resource.EmailTemplate, email_template) - def create_new_email_template(self, data: Dict[str, Any]) -> EmailTemplate: + def create_new_email_template(self, data: dict[str, Any]) -> EmailTemplate: """https://elis.rossum.ai/api/docs/#create-new-email-template-object.""" - return self._run_coroutine(self.elis_api_client.create_new_email_template(data)) + email_template = self.internal_client.create(Resource.EmailTemplate, data) + return self._deserializer(Resource.EmailTemplate, email_template) # ##### CONNECTORS ##### - def list_all_connectors( + + def list_connectors( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[Connector]: """https://elis.rossum.ai/api/docs/#list-all-connectors.""" - return self._iter_over_async(self.elis_api_client.list_all_connectors(ordering, **filters)) + for c in self.internal_client.fetch_resources(Resource.Connector, ordering, **filters): + yield self._deserializer(Resource.Connector, c) def retrieve_connector(self, connector_id: int) -> Connector: """https://elis.rossum.ai/api/docs/#retrieve-a-connector.""" - return self._run_coroutine(self.elis_api_client.retrieve_connector(connector_id)) + connector = self.internal_client.fetch_resource(Resource.Connector, connector_id) + return self._deserializer(Resource.Connector, connector) - def create_new_connector(self, data: Dict[str, Any]) -> Connector: + def create_new_connector(self, data: dict[str, Any]) -> Connector: """https://elis.rossum.ai/api/docs/#create-a-new-connector.""" - return self._run_coroutine(self.elis_api_client.create_new_connector(data)) + connector = self.internal_client.create(Resource.Connector, data) + return self._deserializer(Resource.Connector, connector) # ##### HOOKS ##### - def list_all_hooks( + + def list_hooks( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[Hook]: """https://elis.rossum.ai/api/docs/#list-all-hooks.""" - return self._iter_over_async(self.elis_api_client.list_all_hooks(ordering, **filters)) + for h in self.internal_client.fetch_resources(Resource.Hook, ordering, **filters): + yield self._deserializer(Resource.Hook, h) def retrieve_hook(self, hook_id: int) -> Hook: """https://elis.rossum.ai/api/docs/#retrieve-a-hook.""" - return self._run_coroutine(self.elis_api_client.retrieve_hook(hook_id)) + hook = self.internal_client.fetch_resource(Resource.Hook, hook_id) + return self._deserializer(Resource.Hook, hook) - def create_new_hook(self, data: Dict[str, Any]) -> Hook: + def create_new_hook(self, data: dict[str, Any]) -> Hook: """https://elis.rossum.ai/api/docs/#create-a-new-hook.""" - return self._run_coroutine(self.elis_api_client.create_new_hook(data)) + hook = self.internal_client.create(Resource.Hook, data) + return self._deserializer(Resource.Hook, hook) - def update_part_hook(self, hook_id: int, data: Dict[str, Any]) -> Hook: + def update_part_hook(self, hook_id: int, data: dict[str, Any]) -> Hook: """https://elis.rossum.ai/api/docs/#update-part-of-a-hook""" - return self._run_coroutine(self.elis_api_client.update_part_hook(hook_id, data)) + hook = self.internal_client.update(Resource.Hook, hook_id, data) + return self._deserializer(Resource.Hook, hook) def delete_hook(self, hook_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-hook""" - return self._run_coroutine(self.elis_api_client.delete_hook(hook_id)) + return self.internal_client.delete(Resource.Hook, hook_id) # ##### USER ROLES ##### - def list_all_user_roles( + + def list_user_roles( self, ordering: Sequence[str] = (), **filters: Any, ) -> Iterator[Group]: """https://elis.rossum.ai/api/docs/#list-all-user-roles.""" - return self._iter_over_async(self.elis_api_client.list_all_user_roles(ordering, **filters)) - - def request_paginated(self, url: str, *args, **kwargs) -> Iterator[dict]: - """Use to perform requests to seldomly used or experimental endpoints with paginated response that do not have - direct support in the client and return Iterator. - """ - return self._iter_over_async(self.elis_api_client.request_paginated(url, *args, **kwargs)) - - def request_json(self, method: str, *args, **kwargs) -> Dict[str, Any]: - """Use to perform requests to seldomly used or experimental endpoints that do not have - direct support in the client and return JSON. - """ - return self._run_coroutine(self.elis_api_client.request_json(method, *args, **kwargs)) - - def request(self, method: str, *args, **kwargs) -> httpx.Response: - """Use to perform requests to seldomly used or experimental endpoints that do not have - direct support in the client and return the raw response. - """ - return self._run_coroutine(self.elis_api_client.request(method, *args, **kwargs)) - - def get_token(self, refresh: bool = False) -> str: - """Returns the current token. Authentication is done automatically if needed. - - Parameters - ---------- - refresh - force refreshing the token - """ - return self._run_coroutine(self.elis_api_client.get_token(refresh)) + for g in self.internal_client.fetch_resources(Resource.Group, ordering, **filters): + yield self._deserializer(Resource.Group, g) diff --git a/rossum_api/internal_sync_client.py b/rossum_api/internal_sync_client.py new file mode 100644 index 0000000..12aec8c --- /dev/null +++ b/rossum_api/internal_sync_client.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union + +import httpx +import tenacity + +from rossum_api import APIClientError +from rossum_api.api_client import Resource +from rossum_api.domain_logic.pagination import build_pagination_params +from rossum_api.domain_logic.retry import AlwaysRetry, should_retry +from rossum_api.domain_logic.sideloads import build_sideload_params, embed_sideloads +from rossum_api.domain_logic.urls import parse_resource_id_from_url +from rossum_api.dtos import Token, UserCredentials +from rossum_api.models import Deserializer, deserialize_default +from rossum_api.utils import enforce_domain + + +class InternalSyncRossumAPIClient: + def __init__( + self, + base_url: str, + credentials: UserCredentials | Token, + deserializer: Optional[Deserializer] = None, + timeout: Optional[float] = None, + n_retries: int = 3, + ): + self.base_url = base_url + self._deserializer = deserializer or deserialize_default + self.client = httpx.Client(timeout=timeout) + self.n_retries = n_retries + + self.token = None + self.username = None + self.password = None + if isinstance(credentials, UserCredentials): + self.username = credentials.username + self.password = credentials.password + else: + self.token = credentials.token + + def _authenticate(self) -> None: + response = self.client.post( + f"{self.base_url}/auth/login", + data={"username": self.username, "password": self.password}, + ) + self._raise_for_status(response) + self.token = response.json()["key"] + + @property + def _headers(self): + return {"Authorization": f"token {self.token}"} + + def create(self, resource: Resource, data: dict[str, Any]) -> dict[str, Any]: + """Create a new object.""" + return self.request_json("POST", resource.value, json=data) + + def replace(self, resource: Resource, id_: int, data: dict[str, Any]) -> dict[str, Any]: + """Modify an entire existing object.""" + return self.request_json("PUT", f"{resource.value}/{id_}", json=data) + + def update(self, resource: Resource, id_: int, data: dict[str, Any]) -> dict[str, Any]: + """Modify particular fields of an existing object.""" + return self.request_json("PATCH", f"{resource.value}/{id_}", json=data) + + def delete(self, resource: Resource, id_: int) -> None: + """Delete a particular object. + + Use with caution: For some objects, it triggers a cascade delete of related objects. + """ + self._request("DELETE", f"{resource.value}/{id_}") + + def upload( + self, + url: str, + files: dict[str, Any], + ) -> dict[str, Any]: + """Upload a file to a resource that supports this.""" + return self.request_json("POST", url, files=files) + + @staticmethod + def _build_export_query_params( + export_format: str, + columns: Sequence[str] = (), + **filters: Any, + ): + query_params = {"format": export_format} + filters = filters or {} + if filters: + query_params = {**query_params, **filters} + if columns: + query_params["columns"] = ",".join(columns) + return query_params + + def export( + self, + resource: Resource, + id_: int, + export_format: str, + http_method: Any, + columns: Sequence[str] = (), + **filters: Any, + ) -> Iterator[Union[dict[str, Any], bytes]]: + query_params = self._build_export_query_params(export_format, columns, **filters) + url = f"{resource.value}/{id_}/export" + + if export_format == "json": + # JSON export is paginated just like a regular fetch_all, it abuses **filters kwargs of + # fetch_all_by_url to pass export-specific query params + for result in self.fetch_all_by_url(url, method=http_method, **query_params): # type: ignore + yield result + else: + # In CSV/XML/XLSX case, all annotations are returned, i.e. the response can be large, + # chunks of bytes are yielded from HTTP stream to keep memory consumption low. + for bytes_chunk in self._stream(http_method, url, params=query_params): + yield bytes_chunk + + def _stream(self, method: str, url: str, *args, **kwargs) -> Iterator[bytes]: + """Performs a streaming HTTP call.""" + if not self.token: + self._authenticate() + + # Do not force the calling site to alway prepend the base URL + enforce_domain(url, self.base_url) + + for attempt in tenacity.Retrying( + wait=tenacity.wait_exponential_jitter(), + retry=tenacity.retry_if_exception(should_retry), + stop=tenacity.stop_after_attempt(self.n_retries), + ): + with ( + attempt, + self.client.stream( + method, url, headers=self._headers, *args, **kwargs + ) as response, + ): + if response.status_code == 401: + self._authenticate() + if attempt.retry_state.attempt_number == 1: + raise AlwaysRetry() + self._raise_for_status(response) + for chunk in response.iter_bytes(): + yield chunk + + def fetch_resource( + self, resource: Resource, id_: Union[int, str], request_params: dict[str, Any] = None + ) -> dict[str, Any]: + """Retrieve a single object in a specific resource. + + Allows passing extra params specifically to allow disabling redirects feature of Tasks. + See https://elis.rossum.ai/api/docs/#task. + If redirects are desired, our raise_for_status wrapper must account for that. + """ + return self.request_json("GET", f"{resource.value}/{id_}", params=request_params) + + def fetch_resources( + self, + resource: Resource, + ordering: Sequence[str] = (), + sideloads: Sequence[str] = (), + content_schema_ids: Sequence[str] = (), + method: str = "GET", + json: Optional[dict] = None, + **filters, + ) -> Iterator[dict[str, Any]]: + """Retrieve a list of objects in a specific resource.""" + for result in self.fetch_resources_by_url( + resource.value, + ordering, + sideloads, + content_schema_ids, + method, + json, + **filters, + ): + yield result + + def fetch_resources_by_url( + self, + url: str, + ordering: Sequence[str] = (), + sideloads: Sequence[str] = (), + content_schema_ids: Sequence[str] = (), + method: str = "GET", + json: Optional[dict] = None, + **filters, + ) -> Iterator[dict[str, Any]]: + query_params = build_pagination_params(ordering) + query_params.update(build_sideload_params(sideloads, content_schema_ids)) + query_params.update(**filters) + + return self._fetch_paginated_results(url, method, query_params, sideloads, json) + + def _fetch_paginated_results(self, url, method, query_params, sideloads, json): + first_page_results, total_pages = self._fetch_page( + url, method, query_params | {"page": 0}, sideloads, json=json + ) + + for r in first_page_results: + yield r + + for page_number in range(2, total_pages + 1): + results, _ = self._fetch_page( + url, method, query_params | {"page": page_number}, sideloads, json=json + ) + for r in results: + yield r + + def _fetch_page( + self, + url: str, + method: str, + query_params: dict[str, Any], + sideload_groups: Sequence[str], + json: Optional[dict] = None, + ) -> Tuple[List[dict[str, Any]], int]: + data = self.request_json(method, url, params=query_params, json=json) + embed_sideloads(data, sideload_groups) + return data["results"], data["pagination"]["total_pages"] + + def request_json(self, method: str, *args, **kwargs) -> dict[str, Any]: + response = self._request(method, *args, **kwargs) + if response.status_code == 204: + return {} + return response.json() + + def request(self, method: str, *args, **kwargs) -> httpx.Response: + response = self._request(method, *args, **kwargs) + return response + + def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Response: + """Performs the actual HTTP call and does error handling. + + Arguments: + ---------- + url + base URL is prepended with base_url if needed + """ + if not self.token: + self._authenticate() + + for attempt in tenacity.Retrying( + wait=tenacity.wait_exponential_jitter(), + retry=tenacity.retry_if_exception(should_retry), + stop=tenacity.stop_after_attempt(self.n_retries), + ): + with attempt: + url = enforce_domain(url, self.base_url) + response = self.client.request(method, url, headers=self._headers, *args, **kwargs) + if response.status_code == 401: + self._authenticate() + if attempt.retry_state.attempt_number == 1: + raise AlwaysRetry() + self._raise_for_status(response) + return response + + @staticmethod + def _raise_for_status(response: httpx.Response): + """Raise an exception in case of HTTP error. + + Re-pack to our own exception class to shield users from the fact that we're using + httpx which should be an implementation detail. + """ + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + content = response.content if response.stream is None else response.read() + raise APIClientError(response.status_code, content.decode("utf-8")) from e + + def sideload(self, resource: dict[str, Any], sideloads: Sequence[str]) -> None: + """Update sideloaded resources in place. + + The API does not support sideloading when fetching a single resource, we need to load + it manually. + """ + fetched_sideloads = [] + for sideload in sideloads: + sideload_url = resource[sideload] + fetched_sideloads.append( + self.fetch_resource(Resource(sideload), parse_resource_id_from_url(sideload_url)) + ) + + for sideload, fetched_sideload in zip(sideloads, fetched_sideloads): + if sideload == "content": # Content (i.e. list of sections is wrapped in a dict) + fetched_sideload = fetched_sideload["content"] + resource[sideload] = fetched_sideload diff --git a/rossum_api/models/__init__.py b/rossum_api/models/__init__.py index 7405c4a..8d3de1e 100644 --- a/rossum_api/models/__init__.py +++ b/rossum_api/models/__init__.py @@ -25,8 +25,8 @@ if TYPE_CHECKING: from typing import Any, Callable, Dict - JsonDict = Dict[str, Any] - Deserializer = Callable[[Resource, JsonDict], Any] +JsonDict = Dict[str, Any] +Deserializer = Callable[[Resource, JsonDict], Any] RESOURCE_TO_MODEL = { diff --git a/rossum_api/utils.py b/rossum_api/utils.py new file mode 100644 index 0000000..0678b04 --- /dev/null +++ b/rossum_api/utils.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import inflect + + +def enforce_domain(url: str, base_url: str) -> str: + """Make sure the url contains the domain.""" + if not url.startswith("https://") and not url.startswith("http://"): + return f"{base_url}/{url}" + return url + + +def to_singular(word): + p = inflect.engine() + singular_form = p.singular_noun(word) + return singular_form if singular_form else word diff --git a/test.py b/test.py index b8585ec..d2ba97b 100644 --- a/test.py +++ b/test.py @@ -2,6 +2,7 @@ It could evolve in time into an E2E test. """ + from __future__ import annotations import asyncio diff --git a/tests/e2e.py b/tests/e2e.py index daa48d5..a449f0a 100644 --- a/tests/e2e.py +++ b/tests/e2e.py @@ -1,11 +1,11 @@ """Integration tests. - These test do not run with the rest of the tests (and did not run in previous versions) - because of the filename. To manually run them, you need to: - - set envars ROSSUM_TOKEN, ROSSUM_BASE_URL and ROSSUM_ORGANIZATION_URL - - pytest tests/e2e.py +These test do not run with the rest of the tests (and did not run in previous versions) +because of the filename. To manually run them, you need to: +- set envars ROSSUM_TOKEN, ROSSUM_BASE_URL and ROSSUM_ORGANIZATION_URL +- pytest tests/e2e.py - In case of permission issues these tests will fail during cleanup. +In case of permission issues these tests will fail during cleanup. """ from __future__ import annotations