diff --git a/rossum_api/elis_api_client_sync.py b/rossum_api/elis_api_client_sync.py index 44d5d63..ca01f05 100644 --- a/rossum_api/elis_api_client_sync.py +++ b/rossum_api/elis_api_client_sync.py @@ -2,6 +2,7 @@ import asyncio import typing +from concurrent.futures import ThreadPoolExecutor from rossum_api import ElisAPIClient @@ -75,29 +76,24 @@ def __init__( self.elis_api_client = ElisAPIClient( username, password, token, base_url, http_client, deserializer ) - - try: - self.event_loop = asyncio.get_running_loop() - if self.event_loop.is_running(): - raise AsyncRuntimeError( - "Event loop is present and already running, please use async version of the client" - ) - except RuntimeError: - self.event_loop = asyncio.new_event_loop() + self.executor = ThreadPoolExecutor() def _iter_over_async(self, ait: AsyncIterator[T]) -> Iterator[T]: - ait = ait.__aiter__() - while True: - try: - obj = self.event_loop.run_until_complete(ait.__anext__()) - yield obj - except StopAsyncIteration: - break + async def async_iter_to_list(ait: AsyncIterator[T]): + # TODO this materializes the whole generator into memory :( + return [obj async for obj in ait] + + future = self.executor.submit(asyncio.run, async_iter_to_list(ait)) # type: ignore + yield from future.result() # type: ignore + + def _run_coroutine(self, couroutine): + future = self.executor.submit(asyncio.run, couroutine) + return future.result() # Wait for the coroutine to complete # ##### QUEUE ##### def retrieve_queue(self, queue_id: int) -> Queue: """https://elis.rossum.ai/api/docs/#retrieve-a-queue-2.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_queue(queue_id)) + return self._run_coroutine(self.elis_api_client.retrieve_queue(queue_id)) def list_all_queues( self, @@ -112,11 +108,11 @@ def create_new_queue( data: Dict[str, Any], ) -> Queue: """https://elis.rossum.ai/api/docs/#create-new-queue.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_queue(data)) + return self._run_coroutine(self.elis_api_client.create_new_queue(data)) def delete_queue(self, queue_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-queue.""" - return self.event_loop.run_until_complete(self.elis_api_client.delete_queue(queue_id)) + return self._run_coroutine(self.elis_api_client.delete_queue(queue_id)) def import_document( self, @@ -143,7 +139,7 @@ def import_document( annotation_ids list of IDs of created annotations, respects the order of `files` argument """ - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.import_document(queue_id, files, values, metadata) ) @@ -175,7 +171,7 @@ def upload_document( Tasks can be polled using poll_task and if succeeded, will contain a link to an Upload object that contains info on uploaded documents/annotations """ - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.upload_document(queue_id, files, values, metadata) ) @@ -185,7 +181,7 @@ def retrieve_upload( ) -> Upload: """Implements https://elis.rossum.ai/api/docs/#retrieve-upload.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_upload(upload_id)) + return self._run_coroutine(self.elis_api_client.retrieve_upload(upload_id)) def export_annotations_to_json(self, queue_id: int) -> Iterator[Annotation]: """https://elis.rossum.ai/api/docs/#export-annotations. @@ -221,13 +217,11 @@ def retrieve_organization( org_id: int, ) -> Organization: """https://elis.rossum.ai/api/docs/#retrieve-an-organization.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_organization(org_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_organization(org_id)) def retrieve_own_organization(self) -> Organization: """Retrieve organization of currently logged in user.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_own_organization()) + return self._run_coroutine(self.elis_api_client.retrieve_own_organization()) # ##### SCHEMAS ##### def list_all_schemas( @@ -240,20 +234,20 @@ def list_all_schemas( def retrieve_schema(self, schema_id: int) -> Schema: """https://elis.rossum.ai/api/docs/#retrieve-a-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_schema(schema_id)) + return self._run_coroutine(self.elis_api_client.retrieve_schema(schema_id)) def create_new_schema(self, data: Dict[str, Any]) -> Schema: """https://elis.rossum.ai/api/docs/#create-a-new-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_schema(data)) + return self._run_coroutine(self.elis_api_client.create_new_schema(data)) def delete_schema(self, schema_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.delete_schema(schema_id)) + return self._run_coroutine(self.elis_api_client.delete_schema(schema_id)) # ##### ENGINES ##### def retrieve_engine(self, engine_id: int) -> Engine: """https://elis.rossum.ai/api/docs/#retrieve-a-schema.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_engine(engine_id)) + return self._run_coroutine(self.elis_api_client.retrieve_engine(engine_id)) # ##### USERS ##### def list_all_users( @@ -266,11 +260,11 @@ def list_all_users( def retrieve_user(self, user_id: int) -> User: """https://elis.rossum.ai/api/docs/#retrieve-a-user-2.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_user(user_id)) + return self._run_coroutine(self.elis_api_client.retrieve_user(user_id)) def create_new_user(self, data: Dict[str, Any]) -> User: """https://elis.rossum.ai/api/docs/#create-new-user.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_user(data)) + return self._run_coroutine(self.elis_api_client.create_new_user(data)) # TODO: specific method in APICLient def change_user_password(self, new_password: str) -> dict: @@ -312,7 +306,7 @@ def search_for_annotations( def retrieve_annotation(self, annotation_id: int, sideloads: Sequence[str] = ()) -> Annotation: """https://elis.rossum.ai/api/docs/#retrieve-an-annotation.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.retrieve_annotation(annotation_id, sideloads) ) @@ -327,7 +321,7 @@ def poll_annotation( Sideloading is done only once after the predicate becomes true to avoid spamming the server. """ - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.poll_annotation(annotation_id, predicate, sleep_s, sideloads) ) @@ -338,9 +332,7 @@ def poll_task( sleep_s: int = 3, ) -> Task: """Poll on Task until predicate is true.""" - return self.event_loop.run_until_complete( - self.elis_api_client.poll_task(task_id, predicate, sleep_s) - ) + return self._run_coroutine(self.elis_api_client.poll_task(task_id, predicate, sleep_s)) def poll_task_until_succeeded( self, @@ -348,17 +340,17 @@ def poll_task_until_succeeded( sleep_s: int = 3, ) -> Task: """Poll on Task until it is succeeded.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.poll_task_until_succeeded(task_id, sleep_s) ) def retrieve_task(self, task_id: int) -> Task: """https://elis.rossum.ai/api/docs/#retrieve-task.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_task(task_id)) + return self._run_coroutine(self.elis_api_client.retrieve_task(task_id)) def poll_annotation_until_imported(self, annotation_id: int, **poll_kwargs: Any) -> Annotation: """A shortcut for waiting until annotation is imported.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.poll_annotation_until_imported(annotation_id, **poll_kwargs) ) @@ -366,7 +358,7 @@ def upload_and_wait_until_imported( self, queue_id: int, filepath: Union[str, pathlib.Path], filename: str, **poll_kwargs ) -> Annotation: """A shortcut for uploading a single file and waiting until its annotation is imported.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.upload_and_wait_until_imported( queue_id, filepath, filename, **poll_kwargs ) @@ -374,17 +366,15 @@ def upload_and_wait_until_imported( def start_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#start-annotation""" - self.event_loop.run_until_complete(self.elis_api_client.start_annotation(annotation_id)) + self._run_coroutine(self.elis_api_client.start_annotation(annotation_id)) def update_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#update-an-annotation.""" - return self.event_loop.run_until_complete( - self.elis_api_client.update_annotation(annotation_id, data) - ) + return self._run_coroutine(self.elis_api_client.update_annotation(annotation_id, data)) def update_part_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#update-part-of-an-annotation.""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.update_part_annotation(annotation_id, data) ) @@ -392,42 +382,34 @@ def bulk_update_annotation_data( self, annotation_id: int, operations: List[Dict[str, Any]] ) -> None: """https://elis.rossum.ai/api/docs/#bulk-update-annotation-data""" - self.event_loop.run_until_complete( + self._run_coroutine( self.elis_api_client.bulk_update_annotation_data(annotation_id, operations) ) def confirm_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#confirm-annotation""" - self.event_loop.run_until_complete(self.elis_api_client.confirm_annotation(annotation_id)) + self._run_coroutine(self.elis_api_client.confirm_annotation(annotation_id)) def create_new_annotation(self, data: dict[str, Any]) -> Annotation: """https://elis.rossum.ai/api/docs/#create-an-annotation""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_annotation(data)) + return self._run_coroutine(self.elis_api_client.create_new_annotation(data)) def delete_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#switch-to-deleted""" - return self.event_loop.run_until_complete( - self.elis_api_client.delete_annotation(annotation_id) - ) + return self._run_coroutine(self.elis_api_client.delete_annotation(annotation_id)) def cancel_annotation(self, annotation_id: int) -> None: """https://elis.rossum.ai/api/docs/#cancel-annotation""" - return self.event_loop.run_until_complete( - self.elis_api_client.cancel_annotation(annotation_id) - ) + return self._run_coroutine(self.elis_api_client.cancel_annotation(annotation_id)) # ##### DOCUMENTS ##### def retrieve_document(self, document_id: int) -> Document: """https://elis.rossum.ai/api/docs/#retrieve-a-document""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_document(document_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_document(document_id)) def retrieve_document_content(self, document_id: int) -> bytes: """https://elis.rossum.ai/api/docs/#document-content""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_document_content(document_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_document_content(document_id)) def create_new_document( self, @@ -437,7 +419,7 @@ def create_new_document( parent: Optional[str] = None, ) -> Document: """https://elis.rossum.ai/api/docs/#create-document""" - return self.event_loop.run_until_complete( + return self._run_coroutine( self.elis_api_client.create_new_document(file_name, file_data, metadata, parent) ) @@ -452,19 +434,15 @@ def list_all_workspaces( def retrieve_workspace(self, workspace_id: int) -> Workspace: """https://elis.rossum.ai/api/docs/#retrieve-a-workspace.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_workspace(workspace_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_workspace(workspace_id)) def create_new_workspace(self, data: Dict[str, Any]) -> Workspace: """https://elis.rossum.ai/api/docs/#create-a-new-workspace.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_workspace(data)) + return self._run_coroutine(self.elis_api_client.create_new_workspace(data)) def delete_workspace(self, workspace_id: int) -> None: """https://elis.rossum.ai/api/docs/#retrieve-a-workspace.""" - return self.event_loop.run_until_complete( - self.elis_api_client.delete_workspace(workspace_id) - ) + return self._run_coroutine(self.elis_api_client.delete_workspace(workspace_id)) # ##### INBOX ##### def create_new_inbox( @@ -472,7 +450,7 @@ def create_new_inbox( data: Dict[str, Any], ) -> Inbox: """https://elis.rossum.ai/api/docs/#create-a-new-inbox.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_inbox(data)) + return self._run_coroutine(self.elis_api_client.create_new_inbox(data)) # ##### EMAIL TEMPLATES ##### def list_all_email_templates( @@ -487,15 +465,11 @@ def list_all_email_templates( def retrieve_email_template(self, email_template_id: int) -> EmailTemplate: """https://elis.rossum.ai/api/docs/#retrieve-an-email-template-object.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_email_template(email_template_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_email_template(email_template_id)) def create_new_email_template(self, data: Dict[str, Any]) -> EmailTemplate: """https://elis.rossum.ai/api/docs/#create-new-email-template-object.""" - return self.event_loop.run_until_complete( - self.elis_api_client.create_new_email_template(data) - ) + return self._run_coroutine(self.elis_api_client.create_new_email_template(data)) # ##### CONNECTORS ##### def list_all_connectors( @@ -508,13 +482,11 @@ def list_all_connectors( def retrieve_connector(self, connector_id: int) -> Connector: """https://elis.rossum.ai/api/docs/#retrieve-a-connector.""" - return self.event_loop.run_until_complete( - self.elis_api_client.retrieve_connector(connector_id) - ) + return self._run_coroutine(self.elis_api_client.retrieve_connector(connector_id)) def create_new_connector(self, data: Dict[str, Any]) -> Connector: """https://elis.rossum.ai/api/docs/#create-a-new-connector.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_connector(data)) + return self._run_coroutine(self.elis_api_client.create_new_connector(data)) # ##### HOOKS ##### def list_all_hooks( @@ -527,21 +499,19 @@ def list_all_hooks( def retrieve_hook(self, hook_id: int) -> Hook: """https://elis.rossum.ai/api/docs/#retrieve-a-hook.""" - return self.event_loop.run_until_complete(self.elis_api_client.retrieve_hook(hook_id)) + return self._run_coroutine(self.elis_api_client.retrieve_hook(hook_id)) def create_new_hook(self, data: Dict[str, Any]) -> Hook: """https://elis.rossum.ai/api/docs/#create-a-new-hook.""" - return self.event_loop.run_until_complete(self.elis_api_client.create_new_hook(data)) + return self._run_coroutine(self.elis_api_client.create_new_hook(data)) def update_part_hook(self, hook_id: int, data: Dict[str, Any]) -> Hook: """https://elis.rossum.ai/api/docs/#update-part-of-a-hook""" - return self.event_loop.run_until_complete( - self.elis_api_client.update_part_hook(hook_id, data) - ) + return self._run_coroutine(self.elis_api_client.update_part_hook(hook_id, data)) def delete_hook(self, hook_id: int) -> None: """https://elis.rossum.ai/api/docs/#delete-a-hook""" - return self.event_loop.run_until_complete(self.elis_api_client.delete_hook(hook_id)) + return self._run_coroutine(self.elis_api_client.delete_hook(hook_id)) # ##### USER ROLES ##### def list_all_user_roles( @@ -562,17 +532,13 @@ def request_json(self, method: str, *args, **kwargs) -> Dict[str, Any]: """Use to perform requests to seldomly used or experimental endpoints that do not have direct support in the client and return JSON. """ - return self.event_loop.run_until_complete( - self.elis_api_client.request_json(method, *args, **kwargs) - ) + return self._run_coroutine(self.elis_api_client.request_json(method, *args, **kwargs)) def request(self, method: str, *args, **kwargs) -> httpx.Response: """Use to perform requests to seldomly used or experimental endpoints that do not have direct support in the client and return the raw response. """ - return self.event_loop.run_until_complete( - self.elis_api_client.request(method, *args, **kwargs) - ) + return self._run_coroutine(self.elis_api_client.request(method, *args, **kwargs)) def get_token(self, refresh: bool = False) -> str: """Returns the current token. Authentication is done automatically if needed. @@ -582,4 +548,4 @@ def get_token(self, refresh: bool = False) -> str: refresh force refreshing the token """ - return self.event_loop.run_until_complete(self.elis_api_client.get_token(refresh)) + return self._run_coroutine(self.elis_api_client.get_token(refresh)) diff --git a/tests/elis_api_client/test_client_sync.py b/tests/elis_api_client/test_client_sync.py index 9214e12..22b5955 100644 --- a/tests/elis_api_client/test_client_sync.py +++ b/tests/elis_api_client/test_client_sync.py @@ -1,42 +1,9 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch - import httpx -import pytest - -from rossum_api import ElisAPIClientSync -from rossum_api.elis_api_client_sync import AsyncRuntimeError class TestClientSync: - def test_new_event_loop(self): - with patch("asyncio.get_running_loop", side_effect=RuntimeError()), patch( - "asyncio.new_event_loop" - ) as new_event_loop_mock: - ElisAPIClientSync("", "", None) - assert new_event_loop_mock.called - - def test_existing_event_loop_not_running(self): - event_loop = MagicMock() - event_loop.is_running = MagicMock(return_value=False) - with patch("asyncio.get_running_loop", return_value=event_loop), patch( - "asyncio.new_event_loop" - ) as new_event_loop_mock: - ElisAPIClientSync("", "", None) - assert not new_event_loop_mock.called - - def test_existing_event_loop_running(self): - event_loop = MagicMock() - event_loop.is_running = MagicMock(return_value=True) - with patch("asyncio.get_running_loop", return_value=event_loop), patch( - "asyncio.new_event_loop" - ) as new_event_loop_mock: - with pytest.raises(AsyncRuntimeError): - ElisAPIClientSync("", "", None) - - assert not new_event_loop_mock.called - def test_request_paginated(self, elis_client_sync, mock_generator): client, http_client = elis_client_sync http_client.fetch_all_by_url.return_value = mock_generator({"some": "json"}) diff --git a/tests/elis_api_client/test_queues.py b/tests/elis_api_client/test_queues.py index 49a5a8c..18d6c8c 100644 --- a/tests/elis_api_client/test_queues.py +++ b/tests/elis_api_client/test_queues.py @@ -257,7 +257,7 @@ async def test_export_annotations_to_file(self, elis_client, mock_file_read): qid = 123 export_format = "xml" - result = [] + result = b"" async for a in client.export_annotations_to_file( queue_id=qid, export_format=export_format ): @@ -266,8 +266,7 @@ async def test_export_annotations_to_file(self, elis_client, mock_file_read): http_client.export.assert_called_with(Resource.Queue, qid, export_format) with open("tests/data/annotation_export.xml", "rb") as fp: - for i, line in enumerate(fp.read()): - assert result[i] == line + assert result == fp.read() class TestQueuesSync: @@ -389,12 +388,11 @@ def test_export_annotations_to_file(self, elis_client_sync, mock_file_read): qid = 123 export_format = "xml" - result = [] + result = b"" for a in client.export_annotations_to_file(queue_id=qid, export_format=export_format): result += a http_client.export.assert_called_with(Resource.Queue, qid, export_format) with open("tests/data/annotation_export.xml", "rb") as fp: - for i, line in enumerate(fp.read()): - assert result[i] == line + assert result == fp.read()