Skip to content

Commit

Permalink
refactor: Backport domain logic to async clients
Browse files Browse the repository at this point in the history
  • Loading branch information
Simona Nemeckova committed Dec 11, 2024
1 parent 1b7496c commit 70bf8d2
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 72 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ requires-python = ">= 3.8"
dependencies = [
"aiofiles",
"dacite",
"inflect",
"httpx",
"tenacity",
"inflect",
]

[project.optional-dependencies]
Expand Down
80 changes: 46 additions & 34 deletions rossum_api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@
import asyncio
import functools
import itertools
import json
import logging
import typing

import httpx
import tenacity

from rossum_api.domain_logic.annotations import get_http_method_for_annotation_export
from rossum_api.domain_logic.pagination import build_pagination_params
from rossum_api.domain_logic.retry import AlwaysRetry, should_retry
from rossum_api.domain_logic.upload import build_upload_files
from rossum_api.domain_logic.urls import (
DEFAULT_BASE_URL,
build_export_url,
build_full_login_url,
build_upload_url,
parse_annotation_id_from_datapoint_url,
parse_resource_id_from_url,
)
from rossum_api.utils import enforce_domain

if typing.TYPE_CHECKING:
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -106,10 +109,10 @@ class APIClient:

def __init__(
self,
base_url: str,
username: Optional[str] = None,
password: Optional[str] = None,
token: Optional[str] = None,
base_url: str = DEFAULT_BASE_URL,
timeout: Optional[float] = None,
n_retries: int = 3,
retry_backoff_factor: float = 1.0,
Expand Down Expand Up @@ -228,9 +231,8 @@ async def fetch_all_by_url(
filters
mapping from resource field to value used to filter records
"""
query_params = {
"page_size": 100,
"ordering": ",".join(ordering),
pagination_params = build_pagination_params(ordering)
query_params = pagination_params | {
"sideload": ",".join(sideloads),
"content.schema_id": ",".join(content_schema_ids),
**filters,
Expand Down Expand Up @@ -300,7 +302,6 @@ def annotation_id(datapoint):
if url is None:
continue
sideload_id = parse_resource_id_from_url(url)

result[sideload_name] = sideloads_by_id[sideload_group].get(
sideload_id, []
) # `content` can have 0 datapoints, use [] default value in this case
Expand Down Expand Up @@ -345,15 +346,9 @@ async def upload(
may be used to initialize values of the object created from the uploaded file,
semantics is different for each resource
"""
files = {"content": (filename, await fp.read(), "application/octet-stream")}

# Filename of values and metadata must be "", otherwise Elis API returns HTTP 400 with body
# "Value must be valid JSON."
if values is not None:
files["values"] = ("", json.dumps(values).encode("utf-8"), "application/json")
if metadata is not None:
files["metadata"] = ("", json.dumps(metadata).encode("utf-8"), "application/json")
return await self.request_json("POST", build_upload_url(resource, id_), files=files)
url = build_upload_url(resource, id_)
files = build_upload_files(await fp.read(), filename, values, metadata)
return await self.request_json("POST", url, files=files)

async def export(
self,
Expand All @@ -371,7 +366,7 @@ async def export(
query_params["columns"] = ",".join(columns)
url = build_export_url(resource, id_)
# The to_status parameter is valid only in POST requests; we can use GET in all other cases
method = "POST" if "to_status" in filters else "GET"
method = get_http_method_for_annotation_export(**filters)
if export_format == "json":
# JSON export is paginated just like a regular fetch_all, it abuses **filters kwargs of
# fetch_all_by_url to pass export-specific query params
Expand Down Expand Up @@ -434,7 +429,6 @@ def should_retry_request(exc: BaseException) -> bool:
reraise=True,
)

@authenticate_if_needed
async def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Response:
"""Performs the actual HTTP call and does error handling.
Expand All @@ -444,29 +438,47 @@ async def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Respon
base URL is prepended with base_url if needed
"""
# Do not force the calling site to always prepend the base URL
if not url.startswith("https://") and not url.startswith("http://"):
url = f"{self.base_url}/{url}"
headers = kwargs.pop("headers", {})
headers["Authorization"] = f"token {self.token}"
if not self.token:
await self._authenticate()
url = enforce_domain(url, self.base_url)

async for attempt in self._retrying():
for attempt in tenacity.Retrying(
wait=tenacity.wait_exponential_jitter(),
retry=tenacity.retry_if_exception(should_retry),
stop=tenacity.stop_after_attempt(self.n_retries),
):
with attempt:
response = await self.client.request(method, url, headers=headers, *args, **kwargs)
response = await self.client.request(
method, url, headers=self._headers, *args, **kwargs
)
if response.status_code == 401:
await self._authenticate()
if attempt.retry_state.attempt_number == 1:
raise AlwaysRetry()
await self._raise_for_status(response)
return response
return response

@authenticate_generator_if_needed
async def _stream(self, method: str, url: str, *args, **kwargs) -> AsyncIterator[bytes]:
"""Performs a streaming HTTP call."""
# Do not force the calling site to always prepend the base URL
if not url.startswith("https://") and not url.startswith("http://"):
url = f"{self.base_url}/{url}"
headers = kwargs.pop("headers", {})
headers["Authorization"] = f"token {self.token}"
async with self.client.stream(method, url, headers=headers, *args, **kwargs) as response:
await self._raise_for_status(response)
async for chunk in response.aiter_bytes():
yield chunk
url = enforce_domain(url, self.base_url)

for attempt in tenacity.Retrying(
wait=tenacity.wait_exponential_jitter(),
retry=tenacity.retry_if_exception(should_retry),
stop=tenacity.stop_after_attempt(self.n_retries),
):
with attempt:
async with self.client.stream(
method, url, headers=self._headers, *args, **kwargs
) as response:
if response.status_code == 401:
await self._authenticate()
if attempt.retry_state.attempt_number == 1:
raise AlwaysRetry()
await self._raise_for_status(response)
async for chunk in response.aiter_bytes():
yield chunk

async def _raise_for_status(self, response: httpx.Response):
"""Raise an exception in case of HTTP error.
Expand Down
6 changes: 3 additions & 3 deletions rossum_api/domain_logic/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, BinaryIO, Optional
from typing import Any, Optional


def build_upload_files(
fp: BinaryIO,
file_content: bytes,
filename: str,
values: Optional[dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""Build request files for the upload endpoint."""
files = {"content": (filename, fp.read(), "application/octet-stream")}
files = {"content": (filename, file_content, "application/octet-stream")}

# Filename of values and metadata must be "", otherwise Elis API returns HTTP 400 with body
# "Value must be valid JSON."
Expand Down
42 changes: 16 additions & 26 deletions rossum_api/elis_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
import aiofiles

from rossum_api.api_client import APIClient
from rossum_api.domain_logic.annotations import (
is_annotation_imported,
validate_list_annotations_params,
)
from rossum_api.domain_logic.documents import build_create_document_params
from rossum_api.domain_logic.resources import Resource
from rossum_api.domain_logic.urls import DEFAULT_BASE_URL
from rossum_api.domain_logic.search import build_search_params, validate_search_params
from rossum_api.domain_logic.urls import DEFAULT_BASE_URL, parse_resource_id_from_url
from rossum_api.models import deserialize_default
from rossum_api.models.task import TaskStatus

Expand Down Expand Up @@ -66,7 +72,7 @@ def __init__(
deserializer
pass a custom deserialization callable if different model classes should be returned
"""
self._http_client = http_client or APIClient(username, password, token, base_url)
self._http_client = http_client or APIClient(base_url, username, password, token)
self._deserializer = deserializer or deserialize_default

# ##### QUEUE #####
Expand Down Expand Up @@ -141,7 +147,7 @@ async def _upload(self, file, queue_id, filename, values, metadata) -> int:
Resource.Queue, queue_id, fp, filename, values, metadata
)
(result,) = results["results"] # We're uploading 1 file in 1 request, we can unpack
return int(result["annotation"].split("/")[-1])
return parse_resource_id_from_url(result["annotation"])

# ##### UPLOAD #####
async def upload_document(
Expand Down Expand Up @@ -200,7 +206,7 @@ async def _create_upload(
files["metadata"] = ("", json.dumps(metadata).encode("utf-8"), "application/json")

task_url = await self.request_json("POST", url, files=files)
task_id = task_url["url"].split("/")[-1]
task_id = parse_resource_id_from_url(task_url["url"])

return await self.retrieve_task(task_id)

Expand Down Expand Up @@ -253,7 +259,7 @@ async def retrieve_organization(self, org_id: int) -> Organization:
async def retrieve_own_organization(self) -> Organization:
"""Retrieve organization of currently logged in user."""
user: Dict[Any, Any] = await self._http_client.fetch_one(Resource.Auth, "user")
organization_id = user["organization"].split("/")[-1]
organization_id = parse_resource_id_from_url(user["organization"])
return await self.retrieve_organization(organization_id)

# ##### SCHEMAS #####
Expand Down Expand Up @@ -321,10 +327,7 @@ async def list_all_annotations(
**filters: Any,
) -> AsyncIterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#list-all-annotations."""
if sideloads and "content" in sideloads and not content_schema_ids:
raise ValueError(
'When content sideloading is requested, "content_schema_ids" must be provided'
)
validate_list_annotations_params(sideloads, content_schema_ids)
async for a in self._http_client.fetch_all(
Resource.Annotation, ordering, sideloads, content_schema_ids, **filters
):
Expand All @@ -339,13 +342,8 @@ async def search_for_annotations(
**kwargs: Any,
) -> AsyncIterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#search-for-annotations."""
if not query and not query_string:
raise ValueError("Either query or query_string must be provided")
json_payload = {}
if query:
json_payload["query"] = query
if query_string:
json_payload["query_string"] = query_string
validate_search_params(query, query_string)
json_payload = build_search_params(query, query_string)

async for a in self._http_client.fetch_all_by_url(
f"{Resource.Annotation.value}/search",
Expand Down Expand Up @@ -394,9 +392,7 @@ async def poll_annotation_until_imported(
self, annotation_id: int, **poll_kwargs: Any
) -> Annotation:
"""A shortcut for waiting until annotation is imported."""
return await self.poll_annotation(
annotation_id, lambda a: a.status not in ("importing", "created"), **poll_kwargs
)
return await self.poll_annotation(annotation_id, is_annotation_imported, **poll_kwargs)

async def poll_task(
self,
Expand Down Expand Up @@ -514,13 +510,7 @@ async def create_new_document(
parent: Optional[str] = None,
) -> Document:
"""https://elis.rossum.ai/api/docs/#create-document"""
metadata = metadata or {}
files: httpx._types.RequestFiles = {
"content": (file_name, file_data),
"metadata": ("", json.dumps(metadata).encode("utf-8")),
}
if parent:
files["parent"] = ("", parent)
files = build_create_document_params(file_name, file_data, metadata, parent)

document = await self._http_client.request_json(
"POST", url=Resource.Document.value, files=files
Expand Down
10 changes: 3 additions & 7 deletions rossum_api/elis_api_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rossum_api.domain_logic.documents import build_create_document_params
from rossum_api.domain_logic.search import build_search_params, validate_search_params
from rossum_api.domain_logic.upload import build_upload_files
from rossum_api.domain_logic.urls import get_upload_url, parse_resource_id_from_url
from rossum_api.domain_logic.urls import build_upload_url, parse_resource_id_from_url
from rossum_api.dtos import Token, UserCredentials
from rossum_api.internal_sync_client import InternalSyncRossumAPIClient
from rossum_api.models import (
Expand Down Expand Up @@ -89,7 +89,7 @@ def _import_document(
results = []
for file_path, filename in files:
with open(file_path, "rb") as fp:
request_files = build_upload_files(fp, filename, values, metadata)
request_files = build_upload_files(fp.read(), filename, values, metadata)
response_data = self.internal_client.upload(url, request_files)
(result,) = response_data[
"results"
Expand Down Expand Up @@ -124,7 +124,7 @@ def import_document(
annotation_ids
list of IDs of created annotations, respects the order of `files` argument
"""
url = get_upload_url(Resource.Queue, queue_id)
url = build_upload_url(Resource.Queue, queue_id)
return self._import_document(url, files, values, metadata)

# ##### UPLOAD #####
Expand Down Expand Up @@ -401,13 +401,11 @@ def update_annotation(self, annotation_id: int, data: dict[str, Any]) -> Annotat
annotation = self.internal_client.replace(Resource.Annotation, annotation_id, data)
return self._deserializer(Resource.Annotation, annotation)


def update_part_annotation(self, annotation_id: int, data: dict[str, Any]) -> Annotation:
"""https://elis.rossum.ai/api/docs/#update-part-of-an-annotation."""
annotation = self.internal_client.update(Resource.Annotation, annotation_id, data)
return self._deserializer(Resource.Annotation, annotation)


def bulk_update_annotation_data(
self, annotation_id: int, operations: list[dict[str, Any]]
) -> None:
Expand Down Expand Up @@ -458,7 +456,6 @@ def retrieve_document_content(self, document_id: int) -> bytes:
)
return document_content.content


def create_new_document(
self,
file_name: str,
Expand All @@ -470,7 +467,6 @@ def create_new_document(
files = build_create_document_params(file_name, file_data, metadata, parent)
document = self.internal_client.request_json(
"POST", url=Resource.Document.value, files=files

)
return self._deserializer(Resource.Document, document)

Expand Down
2 changes: 1 addition & 1 deletion rossum_api/internal_sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ def _request(self, method: str, url: str, *args, **kwargs) -> httpx.Response:
"""
if not self.token:
self._authenticate()
url = enforce_domain(url, self.base_url)

for attempt in tenacity.Retrying(
wait=tenacity.wait_exponential_jitter(),
retry=tenacity.retry_if_exception(should_retry),
stop=tenacity.stop_after_attempt(self.n_retries),
):
with attempt:
url = enforce_domain(url, self.base_url)
response = self.client.request(method, url, headers=self._headers, *args, **kwargs)
if response.status_code == 401:
self._authenticate()
Expand Down

0 comments on commit 70bf8d2

Please sign in to comment.