From 70bf8d2cb4a06ab261f655cf1005982b8a99d1fc Mon Sep 17 00:00:00 2001
From: Simona Nemeckova
Date: Wed, 11 Dec 2024 14:52:33 +0100
Subject: [PATCH] refactor: Backport domain logic to async clients

---
 pyproject.toml                     |  2 +-
 rossum_api/api_client.py           | 80 +++++++++++++++++-------------
 rossum_api/domain_logic/upload.py  |  6 +--
 rossum_api/elis_api_client.py      | 42 ++++++----------
 rossum_api/elis_api_client_sync.py | 10 ++--
 rossum_api/internal_sync_client.py |  2 +-
 6 files changed, 70 insertions(+), 72 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 87287c7..5dc791b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,9 +11,9 @@ requires-python = ">= 3.8"
 dependencies = [
     "aiofiles",
     "dacite",
+    "inflect",
     "httpx",
     "tenacity",
-    "inflect",
 ]
 
 [project.optional-dependencies]
diff --git a/rossum_api/api_client.py b/rossum_api/api_client.py
index 31dbc0a..8e8adb8 100644
--- a/rossum_api/api_client.py
+++ b/rossum_api/api_client.py
@@ -3,21 +3,24 @@
 import asyncio
 import functools
 import itertools
-import json
 import logging
 import typing
 
 import httpx
 import tenacity
 
+from rossum_api.domain_logic.annotations import get_http_method_for_annotation_export
+from rossum_api.domain_logic.pagination import build_pagination_params
+from rossum_api.domain_logic.retry import AlwaysRetry, should_retry
+from rossum_api.domain_logic.upload import build_upload_files
 from rossum_api.domain_logic.urls import (
-    DEFAULT_BASE_URL,
     build_export_url,
     build_full_login_url,
     build_upload_url,
     parse_annotation_id_from_datapoint_url,
     parse_resource_id_from_url,
 )
+from rossum_api.utils import enforce_domain
 
 if typing.TYPE_CHECKING:
     from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union
@@ -106,10 +109,10 @@ class APIClient:
 
     def __init__(
        self,
+        base_url: str,
         username: Optional[str] = None,
         password: Optional[str] = None,
         token: Optional[str] = None,
-        base_url: str = DEFAULT_BASE_URL,
         timeout: Optional[float] = None,
         n_retries: int = 3,
         retry_backoff_factor: float = 1.0,
@@ -228,9 +231,8 @@ async def fetch_all_by_url(
         filters
             mapping from resource field to value used to filter records
         """
-        query_params = {
-            "page_size": 100,
-            "ordering": ",".join(ordering),
+        pagination_params = build_pagination_params(ordering)
+        query_params = pagination_params | {
             "sideload": ",".join(sideloads),
             "content.schema_id": ",".join(content_schema_ids),
             **filters,
         }
@@ -300,7 +302,6 @@ def annotation_id(datapoint):
             if url is None:
                 continue
             sideload_id = parse_resource_id_from_url(url)
-
             result[sideload_name] = sideloads_by_id[sideload_group].get(
                 sideload_id, []
             )  # `content` can have 0 datapoints, use [] default value in this case
@@ -345,15 +346,9 @@ async def upload(
             may be used to initialize values of the object created from the uploaded file,
             semantics is different for each resource
         """
-        files = {"content": (filename, await fp.read(), "application/octet-stream")}
-
-        # Filename of values and metadata must be "", otherwise Elis API returns HTTP 400 with body
-        # "Value must be valid JSON."
-        if values is not None:
-            files["values"] = ("", json.dumps(values).encode("utf-8"), "application/json")
-        if metadata is not None:
-            files["metadata"] = ("", json.dumps(metadata).encode("utf-8"), "application/json")
-        return await self.request_json("POST", build_upload_url(resource, id_), files=files)
+        url = build_upload_url(resource, id_)
+        files = build_upload_files(await fp.read(), filename, values, metadata)
+        return await self.request_json("POST", url, files=files)
 
     async def export(
         self,
@@ -371,7 +366,7 @@ async def export(
             query_params["columns"] = ",".join(columns)
         url = build_export_url(resource, id_)
         # to_status parameter is valid only in POST requests, we can use GET in all other cases
-        method = "POST" if "to_status" in filters else "GET"
+        method = get_http_method_for_annotation_export(**filters)
         if export_format == "json":
             # JSON export is paginated just like a regular fetch_all, it abuses **filters kwargs of
             # fetch_all_by_url to pass export-specific query params
@@ -434,7 +429,6 @@ def should_retry_request(exc: BaseException) -> bool:
             reraise=True,
         )
 
-    @authenticate_if_needed
     async def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Response:
         """Performs the actual HTTP call and does error handling.
 
@@ -444,29 +438,47 @@ async def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Respon
             base URL is prepended with base_url if needed
         """
         # Do not force the calling site to always prepend the base URL
-        if not url.startswith("https://") and not url.startswith("http://"):
-            url = f"{self.base_url}/{url}"
-        headers = kwargs.pop("headers", {})
-        headers["Authorization"] = f"token {self.token}"
+        if not self.token:
+            await self._authenticate()
+        url = enforce_domain(url, self.base_url)
 
-        async for attempt in self._retrying():
+        for attempt in tenacity.Retrying(
+            wait=tenacity.wait_exponential_jitter(),
+            retry=tenacity.retry_if_exception(should_retry),
+            stop=tenacity.stop_after_attempt(self.n_retries),
+        ):
             with attempt:
-                response = await self.client.request(method, url, headers=headers, *args, **kwargs)
+                response = await self.client.request(
+                    method, url, headers=self._headers, *args, **kwargs
+                )
+                if response.status_code == 401:
+                    await self._authenticate()
+                    if attempt.retry_state.attempt_number == 1:
+                        raise AlwaysRetry()
                 await self._raise_for_status(response)
-        return response
+                return response
 
-    @authenticate_generator_if_needed
     async def _stream(self, method: str, url: str, *args, **kwargs) -> AsyncIterator[bytes]:
         """Performs a streaming HTTP call."""
         # Do not force the calling site to alway prepend the base URL
-        if not url.startswith("https://") and not url.startswith("http://"):
-            url = f"{self.base_url}/{url}"
-        headers = kwargs.pop("headers", {})
-        headers["Authorization"] = f"token {self.token}"
-        async with self.client.stream(method, url, headers=headers, *args, **kwargs) as response:
-            await self._raise_for_status(response)
-            async for chunk in response.aiter_bytes():
-                yield chunk
+        url = enforce_domain(url, self.base_url)
+
+        for attempt in tenacity.Retrying(
+            wait=tenacity.wait_exponential_jitter(),
+            retry=tenacity.retry_if_exception(should_retry),
+            stop=tenacity.stop_after_attempt(self.n_retries),
+        ):
+            with attempt:
+                async with self.client.stream(
+                    method, url, headers=self._headers, *args, **kwargs
+                ) as response:
+                    if response.status_code == 401:
+                        await self._authenticate()
+                        if attempt.retry_state.attempt_number == 1:
+                            raise AlwaysRetry()
+                    await self._raise_for_status(response)
+                    async for chunk in response.aiter_bytes():
+                        yield chunk
 
     async def _raise_for_status(self, response: httpx.Response):
         """Raise an exception in case of HTTP error.
diff --git a/rossum_api/domain_logic/upload.py b/rossum_api/domain_logic/upload.py
index 35316e3..c4064e2 100644
--- a/rossum_api/domain_logic/upload.py
+++ b/rossum_api/domain_logic/upload.py
@@ -4,17 +4,17 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from typing import Any, BinaryIO, Optional
+    from typing import Any, Optional
 
 
 def build_upload_files(
-    fp: BinaryIO,
+    file_content: bytes,
     filename: str,
     values: Optional[dict[str, Any]] = None,
     metadata: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any]:
     """Build request files for the upload endpoint."""
-    files = {"content": (filename, fp.read(), "application/octet-stream")}
+    files = {"content": (filename, file_content, "application/octet-stream")}
 
     # Filename of values and metadata must be "", otherwise Elis API returns HTTP 400 with body
     # "Value must be valid JSON."
diff --git a/rossum_api/elis_api_client.py b/rossum_api/elis_api_client.py
index ac97913..d8460f5 100644
--- a/rossum_api/elis_api_client.py
+++ b/rossum_api/elis_api_client.py
@@ -8,8 +8,14 @@
 import aiofiles
 
 from rossum_api.api_client import APIClient
+from rossum_api.domain_logic.annotations import (
+    is_annotation_imported,
+    validate_list_annotations_params,
+)
+from rossum_api.domain_logic.documents import build_create_document_params
 from rossum_api.domain_logic.resources import Resource
-from rossum_api.domain_logic.urls import DEFAULT_BASE_URL
+from rossum_api.domain_logic.search import build_search_params, validate_search_params
+from rossum_api.domain_logic.urls import DEFAULT_BASE_URL, parse_resource_id_from_url
 from rossum_api.models import deserialize_default
 from rossum_api.models.task import TaskStatus
 
@@ -66,7 +72,7 @@ def __init__(
         deserializer
             pass a custom deserialization callable if different model classes should be returned
         """
-        self._http_client = http_client or APIClient(username, password, token, base_url)
+        self._http_client = http_client or APIClient(base_url, username, password, token)
         self._deserializer = deserializer or deserialize_default
 
     # ##### QUEUE #####
@@ -141,7 +147,7 @@ async def _upload(self, file, queue_id, filename, values, metadata) -> int:
                 Resource.Queue, queue_id, fp, filename, values, metadata
             )
             (result,) = results["results"]  # We're uploading 1 file in 1 request, we can unpack
-            return int(result["annotation"].split("/")[-1])
+            return parse_resource_id_from_url(result["annotation"])
 
     # ##### UPLOAD #####
     async def upload_document(
@@ -200,7 +206,7 @@ async def _create_upload(
             files["metadata"] = ("", json.dumps(metadata).encode("utf-8"), "application/json")
 
         task_url = await self.request_json("POST", url, files=files)
-        task_id = task_url["url"].split("/")[-1]
+        task_id = parse_resource_id_from_url(task_url["url"])
 
         return await self.retrieve_task(task_id)
 
@@ -253,7 +259,7 @@ async def retrieve_organization(self, org_id: int) -> Organization:
     async def retrieve_own_organization(self) -> Organization:
         """Retrieve organization of currently logged in user."""
         user: Dict[Any, Any] = await self._http_client.fetch_one(Resource.Auth, "user")
-        organization_id = user["organization"].split("/")[-1]
+        organization_id = parse_resource_id_from_url(user["organization"])
         return await self.retrieve_organization(organization_id)
 
     # ##### SCHEMAS #####
@@ -321,10 +327,7 @@ async def list_all_annotations(
         **filters: Any,
     ) -> AsyncIterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#list-all-annotations.""" - if sideloads and "content" in sideloads and not content_schema_ids: - raise ValueError( - 'When content sideloading is requested, "content_schema_ids" must be provided' - ) + validate_list_annotations_params(sideloads, content_schema_ids) async for a in self._http_client.fetch_all( Resource.Annotation, ordering, sideloads, content_schema_ids, **filters ): @@ -339,13 +342,8 @@ async def search_for_annotations( **kwargs: Any, ) -> AsyncIterator[Annotation]: """https://elis.rossum.ai/api/docs/#search-for-annotations.""" - if not query and not query_string: - raise ValueError("Either query or query_string must be provided") - json_payload = {} - if query: - json_payload["query"] = query - if query_string: - json_payload["query_string"] = query_string + validate_search_params(query, query_string) + json_payload = build_search_params(query, query_string) async for a in self._http_client.fetch_all_by_url( f"{Resource.Annotation.value}/search", @@ -394,9 +392,7 @@ async def poll_annotation_until_imported( self, annotation_id: int, **poll_kwargs: Any ) -> Annotation: """A shortcut for waiting until annotation is imported.""" - return await self.poll_annotation( - annotation_id, lambda a: a.status not in ("importing", "created"), **poll_kwargs - ) + return await self.poll_annotation(annotation_id, is_annotation_imported, **poll_kwargs) async def poll_task( self, @@ -514,13 +510,7 @@ async def create_new_document( parent: Optional[str] = None, ) -> Document: """https://elis.rossum.ai/api/docs/#create-document""" - metadata = metadata or {} - files: httpx._types.RequestFiles = { - "content": (file_name, file_data), - "metadata": ("", json.dumps(metadata).encode("utf-8")), - } - if parent: - files["parent"] = ("", parent) + files = build_create_document_params(file_name, file_data, metadata, parent) document = await self._http_client.request_json( "POST", url=Resource.Document.value, files=files diff --git a/rossum_api/elis_api_client_sync.py b/rossum_api/elis_api_client_sync.py index c2a2960..2f42104 100644 --- a/rossum_api/elis_api_client_sync.py +++ b/rossum_api/elis_api_client_sync.py @@ -15,7 +15,7 @@ from rossum_api.domain_logic.documents import build_create_document_params from rossum_api.domain_logic.search import build_search_params, validate_search_params from rossum_api.domain_logic.upload import build_upload_files -from rossum_api.domain_logic.urls import get_upload_url, parse_resource_id_from_url +from rossum_api.domain_logic.urls import build_upload_url, parse_resource_id_from_url from rossum_api.dtos import Token, UserCredentials from rossum_api.internal_sync_client import InternalSyncRossumAPIClient from rossum_api.models import ( @@ -89,7 +89,7 @@ def _import_document( results = [] for file_path, filename in files: with open(file_path, "rb") as fp: - request_files = build_upload_files(fp, filename, values, metadata) + request_files = build_upload_files(fp.read(), filename, values, metadata) response_data = self.internal_client.upload(url, request_files) (result,) = response_data[ "results" @@ -124,7 +124,7 @@ def import_document( annotation_ids list of IDs of created annotations, respects the order of `files` argument """ - url = get_upload_url(Resource.Queue, queue_id) + url = build_upload_url(Resource.Queue, queue_id) return self._import_document(url, files, values, metadata) # ##### UPLOAD ##### @@ -401,13 +401,11 @@ def update_annotation(self, annotation_id: int, data: dict[str, Any]) -> Annotat annotation = 
self.internal_client.replace(Resource.Annotation, annotation_id, data) return self._deserializer(Resource.Annotation, annotation) - def update_part_annotation(self, annotation_id: int, data: dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#update-part-of-an-annotation.""" annotation = self.internal_client.update(Resource.Annotation, annotation_id, data) return self._deserializer(Resource.Annotation, annotation) - def bulk_update_annotation_data( self, annotation_id: int, operations: list[dict[str, Any]] ) -> None: @@ -458,7 +456,6 @@ def retrieve_document_content(self, document_id: int) -> bytes: ) return document_content.content - def create_new_document( self, file_name: str, @@ -470,7 +467,6 @@ def create_new_document( files = build_create_document_params(file_name, file_data, metadata, parent) document = self.internal_client.request_json( "POST", url=Resource.Document.value, files=files - ) return self._deserializer(Resource.Document, document) diff --git a/rossum_api/internal_sync_client.py b/rossum_api/internal_sync_client.py index 12aec8c..9616472 100644 --- a/rossum_api/internal_sync_client.py +++ b/rossum_api/internal_sync_client.py @@ -238,6 +238,7 @@ def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Response: """ if not self.token: self._authenticate() + url = enforce_domain(url, self.base_url) for attempt in tenacity.Retrying( wait=tenacity.wait_exponential_jitter(), @@ -245,7 +246,6 @@ def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Response: stop=tenacity.stop_after_attempt(self.n_retries), ): with attempt: - url = enforce_domain(url, self.base_url) response = self.client.request(method, url, headers=self._headers, *args, **kwargs) if response.status_code == 401: self._authenticate()
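
Note for reviewers (illustrative addition, not part of the patch): the most visible caller-facing changes are that APIClient now takes base_url as the first positional argument with no default, and build_upload_files now expects the file's bytes rather than a file object. A minimal sketch of updated call sites under those assumptions; the base URL and file name below are hypothetical:

    from rossum_api.api_client import APIClient
    from rossum_api.domain_logic.upload import build_upload_files

    # base_url is now the first positional parameter and no longer defaults to DEFAULT_BASE_URL.
    client = APIClient(
        "https://example.rossum.app/api/v1",  # hypothetical deployment URL
        username="user@example.com",
        password="secret",
    )

    # build_upload_files now receives already-read bytes, so the caller reads the file itself.
    with open("invoice.pdf", "rb") as fp:  # hypothetical file
        files = build_upload_files(fp.read(), "invoice.pdf")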