From a52c558ab111a5b1c6e024484fde2fc98cb1f8ef Mon Sep 17 00:00:00 2001 From: Oskar Hollmann Date: Mon, 9 Dec 2024 13:43:24 +0100 Subject: [PATCH] fix: Do not crash when running sync client in async context The asyncio code is now executed in a ThreadPoolExecutor to avoid problems with multiple event loops. --- rossum_api/elis_api_client_sync.py | 165 ++++++++++------------ tests/elis_api_client/test_client_sync.py | 38 +---- 2 files changed, 80 insertions(+), 123 deletions(-) diff --git a/rossum_api/elis_api_client_sync.py b/rossum_api/elis_api_client_sync.py index 44d5d63..c0159bc 100644 --- a/rossum_api/elis_api_client_sync.py +++ b/rossum_api/elis_api_client_sync.py @@ -2,6 +2,8 @@ import asyncio import typing +from concurrent.futures import ThreadPoolExecutor +from queue import Queue as ThreadSafeQueue from rossum_api import ElisAPIClient @@ -75,29 +77,40 @@ def __init__( self.elis_api_client = ElisAPIClient( username, password, token, base_url, http_client, deserializer ) - - try: - self.event_loop = asyncio.get_running_loop() - if self.event_loop.is_running(): - raise AsyncRuntimeError( - "Event loop is present and already running, please use async version of the client" - ) - except RuntimeError: - self.event_loop = asyncio.new_event_loop() + self.executor = ThreadPoolExecutor() def _iter_over_async(self, ait: AsyncIterator[T]) -> Iterator[T]: - ait = ait.__aiter__() - while True: + """Iterate over an async generator from sync code without materializing all items into memory.""" + queue: ThreadSafeQueue = ( + ThreadSafeQueue() + ) # To communicate with the thread executing the async generator + + async def async_iter_to_list(ait: AsyncIterator[T], queue: ThreadSafeQueue): try: - obj = self.event_loop.run_until_complete(ait.__anext__()) - yield obj - except StopAsyncIteration: + async for obj in ait: + queue.put(obj) + finally: + queue.put(None) # Signal iterator was consumed + + future = self.executor.submit(asyncio.run, async_iter_to_list(ait, queue)) # type: ignore + + # Consume the queue until completion to retain the iterator nature even in sync context + while True: + item = queue.get() + if item is None: # None is used to signal completion break + yield item + + future.result() + + def _run_coroutine(self, couroutine): + future = self.executor.submit(asyncio.run, couroutine) + return future.result() # Wait for the coroutine to complete # ##### QUEUE ##### def retrieve_queue(self, queue_id: int) -> Queue: """https://elis.rossum.ai/api/docs/#retrieve-a-queue-2.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_queue(queue_id)) + return self._run_coroutine(self.elis_api_client.retrieve_queue(queue_id)) def list_all_queues( self, @@ -112,11 +125,11 @@ def create_new_queue( data: Dict[str, Any], ) -> Queue: """https://elis.rossum.ai/api/docs/#create-new-queue.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_queue(data)) + return self._run_coroutine(self.elis_api_client.create_new_queue(data)) def delete_queue(self, queue_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-queue.""" - return self.event_loop.run_until_complete(self.elis_api_client.delete_queue(queue_id)) + return self._run_coroutine(self.elis_api_client.delete_queue(queue_id)) def import_document( self, @@ -143,7 +156,7 @@ def import_document( annotation_ids list of IDs of created annotations, respects the order of `files` argument """ - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.import_document(queue_id, files, values, metadata) ) @@ -175,7 +188,7 @@ def upload_document( Tasks can be polled using poll_task and if succeeded, will contain a link to an Upload object that contains info on uploaded documents/annotations """ - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.upload_document(queue_id, files, values, metadata) ) @@ -185,7 +198,7 @@ def retrieve_upload( ) -> Upload: """Implements https://elis.rossum.ai/api/docs/#retrieve-upload.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_upload(upload_id)) + return self._run_coroutine(self.elis_api_client.retrieve_upload(upload_id)) def export_annotations_to_json(self, queue_id: int) -> Iterator[Annotation]: """https://elis.rossum.ai/api/docs/#export-annotations. @@ -221,13 +234,11 @@ def retrieve_organization( org_id: int, ) -> Organization: """https://elis.rossum.ai/api/docs/#retrieve-an-organization.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_organization(org_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_organization(org_id)) def retrieve_own_organization(self) -> Organization: """Retrieve organization of currently logged in user.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_own_organization()) + return self._run_coroutine(self.elis_api_client.retrieve_own_organization()) # ##### SCHEMAS ##### def list_all_schemas( @@ -240,20 +251,20 @@ def list_all_schemas( def retrieve_schema(self, schema_id: int) -> Schema: """https://elis.rossum.ai/api/docs/#retrieve-a-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_schema(schema_id)) + return self._run_coroutine(self.elis_api_client.retrieve_schema(schema_id)) def create_new_schema(self, data: Dict[str, Any]) -> Schema: """https://elis.rossum.ai/api/docs/#create-a-new-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_schema(data)) + return self._run_coroutine(self.elis_api_client.create_new_schema(data)) def delete_schema(self, schema_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.delete_schema(schema_id)) + return self._run_coroutine(self.elis_api_client.delete_schema(schema_id)) # ##### ENGINES ##### def retrieve_engine(self, engine_id: int) -> Engine: """https://elis.rossum.ai/api/docs/#retrieve-a-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_engine(engine_id)) + return self._run_coroutine(self.elis_api_client.retrieve_engine(engine_id)) # ##### USERS ##### def list_all_users( @@ -266,11 +277,11 @@ def list_all_users( def retrieve_user(self, user_id: int) -> User: """https://elis.rossum.ai/api/docs/#retrieve-a-user-2.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_user(user_id)) + return self._run_coroutine(self.elis_api_client.retrieve_user(user_id)) def create_new_user(self, data: Dict[str, Any]) -> User: """https://elis.rossum.ai/api/docs/#create-new-user.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_user(data)) + return self._run_coroutine(self.elis_api_client.create_new_user(data)) # TODO: specific method in APICLient def change_user_password(self, new_password: str) -> dict: @@ -312,7 +323,7 @@ def search_for_annotations( def retrieve_annotation(self, annotation_id: int, sideloads: Sequence[str] = ()) -> Annotation: """https://elis.rossum.ai/api/docs/#retrieve-an-annotation.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.retrieve_annotation(annotation_id, sideloads) ) @@ -327,7 +338,7 @@ def poll_annotation( Sideloading is done only once after the predicate becomes true to avoid spamming the server. """ - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.poll_annotation(annotation_id, predicate, sleep_s, sideloads) ) @@ -338,9 +349,7 @@ def poll_task( sleep_s: int = 3, ) -> Task: """Poll on Task until predicate is true.""" - return self.event_loop.run_until_complete( - self.elis_api_client.poll_task(task_id, predicate, sleep_s) - ) + return self._run_coroutine(self.elis_api_client.poll_task(task_id, predicate, sleep_s)) def poll_task_until_succeeded( self, @@ -348,17 +357,17 @@ def poll_task_until_succeeded( sleep_s: int = 3, ) -> Task: """Poll on Task until it is succeeded.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.poll_task_until_succeeded(task_id, sleep_s) ) def retrieve_task(self, task_id: int) -> Task: """https://elis.rossum.ai/api/docs/#retrieve-task.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_task(task_id)) + return self._run_coroutine(self.elis_api_client.retrieve_task(task_id)) def poll_annotation_until_imported(self, annotation_id: int, **poll_kwargs: Any) -> Annotation: """A shortcut for waiting until annotation is imported.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.poll_annotation_until_imported(annotation_id, **poll_kwargs) ) @@ -366,7 +375,7 @@ def upload_and_wait_until_imported( self, queue_id: int, filepath: Union[str, pathlib.Path], filename: str, **poll_kwargs ) -> Annotation: """A shortcut for uploading a single file and waiting until its annotation is imported.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.upload_and_wait_until_imported( queue_id, filepath, filename, **poll_kwargs ) @@ -374,17 +383,15 @@ def upload_and_wait_until_imported( def start_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#start-annotation""" - self.event_loop.run_until_complete(self.elis_api_client.start_annotation(annotation_id)) + self._run_coroutine(self.elis_api_client.start_annotation(annotation_id)) def update_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#update-an-annotation.""" - return self.event_loop.run_until_complete( - self.elis_api_client.update_annotation(annotation_id, data) - ) + return self._run_coroutine(self.elis_api_client.update_annotation(annotation_id, data)) def update_part_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#update-part-of-an-annotation.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.update_part_annotation(annotation_id, data) ) @@ -392,42 +399,34 @@ def bulk_update_annotation_data( self, annotation_id: int, operations: List[Dict[str, Any]] ) -> None: """https://elis.rossum.ai/api/docs/#bulk-update-annotation-data""" - self.event_loop.run_until_complete( + self._run_coroutine( self.elis_api_client.bulk_update_annotation_data(annotation_id, operations) ) def confirm_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#confirm-annotation""" - self.event_loop.run_until_complete(self.elis_api_client.confirm_annotation(annotation_id)) + self._run_coroutine(self.elis_api_client.confirm_annotation(annotation_id)) def create_new_annotation(self, data: dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#create-an-annotation""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_annotation(data)) + return self._run_coroutine(self.elis_api_client.create_new_annotation(data)) def delete_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#switch-to-deleted""" - return self.event_loop.run_until_complete( - self.elis_api_client.delete_annotation(annotation_id) - ) + return self._run_coroutine(self.elis_api_client.delete_annotation(annotation_id)) def cancel_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#cancel-annotation""" - return self.event_loop.run_until_complete( - self.elis_api_client.cancel_annotation(annotation_id) - ) + return self._run_coroutine(self.elis_api_client.cancel_annotation(annotation_id)) # ##### DOCUMENTS ##### def retrieve_document(self, document_id: int) -> Document: """https://elis.rossum.ai/api/docs/#retrieve-a-document""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_document(document_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_document(document_id)) def retrieve_document_content(self, document_id: int) -> bytes: """https://elis.rossum.ai/api/docs/#document-content""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_document_content(document_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_document_content(document_id)) def create_new_document( self, @@ -437,7 +436,7 @@ def create_new_document( parent: Optional[str] = None, ) -> Document: """https://elis.rossum.ai/api/docs/#create-document""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.create_new_document(file_name, file_data, metadata, parent) ) @@ -452,19 +451,15 @@ def list_all_workspaces( def retrieve_workspace(self, workspace_id: int) -> Workspace: """https://elis.rossum.ai/api/docs/#retrieve-a-workspace.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_workspace(workspace_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_workspace(workspace_id)) def create_new_workspace(self, data: Dict[str, Any]) -> Workspace: """https://elis.rossum.ai/api/docs/#create-a-new-workspace.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_workspace(data)) + return self._run_coroutine(self.elis_api_client.create_new_workspace(data)) def delete_workspace(self, workspace_id: int) -> None: """https://elis.rossum.ai/api/docs/#retrieve-a-workspace.""" - return self.event_loop.run_until_complete( - self.elis_api_client.delete_workspace(workspace_id) - ) + return self._run_coroutine(self.elis_api_client.delete_workspace(workspace_id)) # ##### INBOX ##### def create_new_inbox( @@ -472,7 +467,7 @@ def create_new_inbox( data: Dict[str, Any], ) -> Inbox: """https://elis.rossum.ai/api/docs/#create-a-new-inbox.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_inbox(data)) + return self._run_coroutine(self.elis_api_client.create_new_inbox(data)) # ##### EMAIL TEMPLATES ##### def list_all_email_templates( @@ -487,15 +482,11 @@ def list_all_email_templates( def retrieve_email_template(self, email_template_id: int) -> EmailTemplate: """https://elis.rossum.ai/api/docs/#retrieve-an-email-template-object.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_email_template(email_template_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_email_template(email_template_id)) def create_new_email_template(self, data: Dict[str, Any]) -> EmailTemplate: """https://elis.rossum.ai/api/docs/#create-new-email-template-object.""" - return self.event_loop.run_until_complete( - self.elis_api_client.create_new_email_template(data) - ) + return self._run_coroutine(self.elis_api_client.create_new_email_template(data)) # ##### CONNECTORS ##### def list_all_connectors( @@ -508,13 +499,11 @@ def list_all_connectors( def retrieve_connector(self, connector_id: int) -> Connector: """https://elis.rossum.ai/api/docs/#retrieve-a-connector.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_connector(connector_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_connector(connector_id)) def create_new_connector(self, data: Dict[str, Any]) -> Connector: """https://elis.rossum.ai/api/docs/#create-a-new-connector.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_connector(data)) + return self._run_coroutine(self.elis_api_client.create_new_connector(data)) # ##### HOOKS ##### def list_all_hooks( @@ -527,21 +516,19 @@ def list_all_hooks( def retrieve_hook(self, hook_id: int) -> Hook: """https://elis.rossum.ai/api/docs/#retrieve-a-hook.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_hook(hook_id)) + return self._run_coroutine(self.elis_api_client.retrieve_hook(hook_id)) def create_new_hook(self, data: Dict[str, Any]) -> Hook: """https://elis.rossum.ai/api/docs/#create-a-new-hook.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_hook(data)) + return self._run_coroutine(self.elis_api_client.create_new_hook(data)) def update_part_hook(self, hook_id: int, data: Dict[str, Any]) -> Hook: """https://elis.rossum.ai/api/docs/#update-part-of-a-hook""" - return self.event_loop.run_until_complete( - self.elis_api_client.update_part_hook(hook_id, data) - ) + return self._run_coroutine(self.elis_api_client.update_part_hook(hook_id, data)) def delete_hook(self, hook_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-hook""" - return self.event_loop.run_until_complete(self.elis_api_client.delete_hook(hook_id)) + return self._run_coroutine(self.elis_api_client.delete_hook(hook_id)) # ##### USER ROLES ##### def list_all_user_roles( @@ -562,17 +549,13 @@ def request_json(self, method: str, *args, **kwargs) -> Dict[str, Any]: """Use to perform requests to seldomly used or experimental endpoints that do not have direct support in the client and return JSON. """ - return self.event_loop.run_until_complete( - self.elis_api_client.request_json(method, *args, **kwargs) - ) + return self._run_coroutine(self.elis_api_client.request_json(method, *args, **kwargs)) def request(self, method: str, *args, **kwargs) -> httpx.Response: """Use to perform requests to seldomly used or experimental endpoints that do not have direct support in the client and return the raw response. """ - return self.event_loop.run_until_complete( - self.elis_api_client.request(method, *args, **kwargs) - ) + return self._run_coroutine(self.elis_api_client.request(method, *args, **kwargs)) def get_token(self, refresh: bool = False) -> str: """Returns the current token. Authentication is done automatically if needed. @@ -582,4 +565,4 @@ def get_token(self, refresh: bool = False) -> str: refresh force refreshing the token """ - return self.event_loop.run_until_complete(self.elis_api_client.get_token(refresh)) + return self._run_coroutine(self.elis_api_client.get_token(refresh)) diff --git a/tests/elis_api_client/test_client_sync.py b/tests/elis_api_client/test_client_sync.py index 9214e12..98dcd8a 100644 --- a/tests/elis_api_client/test_client_sync.py +++ b/tests/elis_api_client/test_client_sync.py @@ -1,42 +1,10 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch - import httpx import pytest -from rossum_api import ElisAPIClientSync -from rossum_api.elis_api_client_sync import AsyncRuntimeError - class TestClientSync: - def test_new_event_loop(self): - with patch("asyncio.get_running_loop", side_effect=RuntimeError()), patch( - "asyncio.new_event_loop" - ) as new_event_loop_mock: - ElisAPIClientSync("", "", None) - assert new_event_loop_mock.called - - def test_existing_event_loop_not_running(self): - event_loop = MagicMock() - event_loop.is_running = MagicMock(return_value=False) - with patch("asyncio.get_running_loop", return_value=event_loop), patch( - "asyncio.new_event_loop" - ) as new_event_loop_mock: - ElisAPIClientSync("", "", None) - assert not new_event_loop_mock.called - - def test_existing_event_loop_running(self): - event_loop = MagicMock() - event_loop.is_running = MagicMock(return_value=True) - with patch("asyncio.get_running_loop", return_value=event_loop), patch( - "asyncio.new_event_loop" - ) as new_event_loop_mock: - with pytest.raises(AsyncRuntimeError): - ElisAPIClientSync("", "", None) - - assert not new_event_loop_mock.called - def test_request_paginated(self, elis_client_sync, mock_generator): client, http_client = elis_client_sync http_client.fetch_all_by_url.return_value = mock_generator({"some": "json"}) @@ -44,6 +12,12 @@ def test_request_paginated(self, elis_client_sync, mock_generator): data = client.request_paginated("hook_templates", **kwargs) assert list(data) == [{"some": "json"}] + def test_request_paginated_propagates_errors(self, elis_client_sync): + client, http_client = elis_client_sync + http_client.fetch_all_by_url.side_effect = Exception("Exception in async code.") + with pytest.raises(Exception, match="Exception in async code."): + list(client.request_paginated("hook_templates")) + def test_request_json(self, elis_client_sync): client, http_client = elis_client_sync http_client.request_json.return_value = {"some": "json"}