From 28091f035ed6d7a063e847a039e4495bc2aa840d Mon Sep 17 00:00:00 2001 From: enrique Date: Fri, 25 Oct 2024 16:14:08 +0200 Subject: [PATCH] feat: support connect and reconnect joining rooms --- payments_py/ai_query_api.py | 16 +++++++-------- payments_py/nvm_backend.py | 39 ++++++++++++++++++++++++++----------- payments_py/payments.py | 1 - tests/protocol_test.py | 12 ++++++------ 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/payments_py/ai_query_api.py b/payments_py/ai_query_api.py index c142b58..72afa78 100644 --- a/payments_py/ai_query_api.py +++ b/payments_py/ai_query_api.py @@ -50,17 +50,15 @@ async def subscribe(self, callback: Any, join_account_room: bool = True, join_ag subscribe_event_types (Optional[List[str]]): The event types to subscribe to. get_pending_events_on_subscribe (bool): If True, it will get the pending events on subscribe. """ - await self._subscribe(callback, join_account_room, join_agent_rooms, subscribe_event_types) - print('query-api:: Connected to the server') - if get_pending_events_on_subscribe: - try: - if(get_pending_events_on_subscribe and join_agent_rooms): - await self._emit_step_events(AgentExecutionStatus.Pending, join_agent_rooms) - except Exception as e: - print('query-api:: Unable to get pending events', e) + self.callback = callback + self.join_account_room = join_account_room + self.join_agent_rooms = join_agent_rooms + self.subscribe_event_types = subscribe_event_types + self.get_pending_events_on_subscribe = get_pending_events_on_subscribe + + await self.connect_socket() await asyncio.Event().wait() - def create_task(self, did: str, task: Any): """ Subscribers can create an AI Task for an Agent. The task must contain the input query that will be used by the AI Agent. diff --git a/payments_py/nvm_backend.py b/payments_py/nvm_backend.py index cbfa380..59c9c2a 100644 --- a/payments_py/nvm_backend.py +++ b/payments_py/nvm_backend.py @@ -3,6 +3,7 @@ import socketio import jwt from typing import Optional, Dict, List, Any, Union +import asyncio from payments_py.data_models import AgentExecutionStatus, ServiceTokenResultDto from payments_py.environments import Environment @@ -37,9 +38,15 @@ class NVMBackendApi: def __init__(self, opts: BackendApiOptions): self.opts = opts self.socket_client = sio + self.connected_event = asyncio.Event() + self.socket_client.on('connect', self.connect_handler) self.user_room_id = None self.has_key = False - + self.callback = None + self.join_account_room = None + self.join_agent_rooms = None + self.subscribe_event_types = None + default_headers = { 'Accept': 'application/json', **(opts.headers or {}), @@ -76,7 +83,7 @@ def __init__(self, opts: BackendApiOptions): self.opts.backend_host = backend_url except Exception as error: raise ValueError(f"Invalid URL: {self.opts.backend_host} - {str(error)}") - + async def connect_socket(self): if not self.has_key: raise ValueError('Unable to subscribe to the server because a key was not provided') @@ -99,28 +106,38 @@ async def disconnect_socket(self): if self.socket_client and self.socket_client.connected: self.socket_client.disconnect() - async def _subscribe(self, callback, join_account_room: bool = True, join_agent_rooms: Optional[Union[str, List[str]]] = None, subscribe_event_types: Optional[List[str]] = None): - if not join_account_room and not join_agent_rooms: + async def connect_handler(self): + while self.socket_client.connected == False: + print('Connecting...') + await asyncio.sleep(1) + await self._subscribe() + if self.get_pending_events_on_subscribe: + try: + if(self.get_pending_events_on_subscribe and self.join_agent_rooms): + await self._emit_step_events(AgentExecutionStatus.Pending, self.join_agent_rooms) + except Exception as e: + print('query-api:: Unable to get pending events', e) + + async def _subscribe(self): + if not self.join_account_room and not self.join_agent_rooms: raise ValueError('No rooms to join in configuration') - await self.connect_socket() if not self.socket_client.connected: raise ConnectionError('Failed to connect to the WebSocket server.') async def event_handler(data): parsed_data = json.loads(data) - await callback(parsed_data) + await self.callback(parsed_data) - await self.join_room(join_account_room, join_agent_rooms) + await self.join_room(self.join_account_room, self.join_agent_rooms) - if subscribe_event_types: - for event in subscribe_event_types: + if self.subscribe_event_types: + for event in self.subscribe_event_types: print(f"nvm-backend:: Subscribing to event: {event}") self.socket_client.on(event, event_handler) else: self.socket_client.on('step-updated', event_handler) - + async def _emit_step_events(self, status: AgentExecutionStatus = AgentExecutionStatus.Pending, dids: List[str] = []): - await self.connect_socket() message = { "status": status.value, "dids": dids } print(f"nvm-backend:: Emitting step: {json.dumps(message)}") await self.socket_client.emit(event='_emit-steps', data=json.dumps(message)) diff --git a/payments_py/payments.py b/payments_py/payments.py index d8cbb96..0f93332 100644 --- a/payments_py/payments.py +++ b/payments_py/payments.py @@ -508,7 +508,6 @@ def create_agent(self, plan_did: str, name: str, description: str, service_charg query_protocol_version, service_host) - def order_plan(self, plan_did: str, agreementId: Optional[str] = None) -> OrderPlanResultDto: """ Orders a Payment Plan. The user needs to have enough balance in the token selected by the owner of the Payment Plan. diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 90694ab..75b20ff 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -176,14 +176,14 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu pass # @pytest.mark.asyncio(loop_scope="session") -# async def test_AI_send_task(ai_query_api_build_fixture): -# builder = ai_query_api_build_fixture +# async def test_AI_send_task(ai_query_api_subscriber_fixture): +# builder = ai_query_api_subscriber_fixture # task = builder.ai_protocol.create_task('did:nv:7d86045034ea8a14c133c487374a175c56a9c6144f6395581435bc7f1dc9e0cc', -# {'query': 'https://www.youtube.com/watch?v=SB7eoaVw4Sk', 'name': 'Summarize video'}) +# {'query': 'https://www.youtube.com/watch?v=0q_BrgesfF4', 'name': 'Summarize video'}) # print('Task created:', task.json()) # @pytest.mark.asyncio(loop_scope="session") -# async def test_AI_send_task2(ai_query_api_build_fixture): -# builder = ai_query_api_build_fixture -# task = builder.ai_protocol.get_task_with_steps(did='did:nv:a8983b06c0f25fb4064fc61d6527c84ca1813e552bfad5fa1c974caa3c5ccf49', task_id='task-cd5a90e6-688f-45a3-a299-1845d10db625') +# async def test_AI_send_task2(ai_query_api_subscriber_fixture): +# builder = ai_query_api_subscriber_fixture +# task = builder.ai_protocol.get_task_with_steps(did='did:nv:7d86045034ea8a14c133c487374a175c56a9c6144f6395581435bc7f1dc9e0cc', task_id='task-6b16b12e-3aa2-43c3-a756-a150b07665e2') # print('Task result:', task.json()) \ No newline at end of file