Skip to content

Commit 5445a3e

Browse files
authored
Merge pull request #2147 from shaangill025/issue_2111
Fix: messages stuck in mediator
2 parents 70b2831 + c900e6b commit 5445a3e

File tree

6 files changed

+273
-13
lines changed

6 files changed

+273
-13
lines changed

aries_cloudagent/admin/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def __init__(
124124
self._profile = weakref.ref(profile)
125125
self._send = send
126126

127-
async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus:
127+
async def send_outbound(
128+
self, message: OutboundMessage, **kwargs
129+
) -> OutboundSendStatus:
128130
"""
129131
Send outbound message.
130132

aries_cloudagent/core/dispatcher.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@
1515

1616
from aiohttp.web import HTTPException
1717

18-
1918
from ..connections.models.conn_record import ConnRecord
2019
from ..core.profile import Profile
2120
from ..messaging.agent_message import AgentMessage
2221
from ..messaging.base_message import BaseMessage
2322
from ..messaging.error import MessageParseError
2423
from ..messaging.models.base import BaseModelError
2524
from ..messaging.request_context import RequestContext
26-
from ..messaging.responder import BaseResponder
25+
from ..messaging.responder import BaseResponder, SKIP_ACTIVE_CONN_CHECK_MSG_TYPES
2726
from ..messaging.util import datetime_now
2827
from ..protocols.connections.v1_0.manager import ConnectionManager
2928
from ..protocols.problem_report.v1_0.message import ProblemReport
@@ -377,7 +376,9 @@ async def create_outbound(
377376

378377
return await super().create_outbound(message, **kwargs)
379378

380-
async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus:
379+
async def send_outbound(
380+
self, message: OutboundMessage, **kwargs
381+
) -> OutboundSendStatus:
381382
"""
382383
Send outbound message.
383384
@@ -388,6 +389,23 @@ async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus:
388389
if not context:
389390
raise RuntimeError("weakref to context has expired")
390391

392+
msg_type = kwargs.get("message_type")
393+
msg_id = kwargs.get("message_id")
394+
395+
if (
396+
message.connection_id
397+
and msg_type
398+
and msg_type not in SKIP_ACTIVE_CONN_CHECK_MSG_TYPES
399+
and not await super().conn_rec_active_state_check(
400+
profile=context.profile,
401+
connection_id=message.connection_id,
402+
)
403+
):
404+
raise RuntimeError(
405+
f"Connection {message.connection_id} is not ready"
406+
" which is required for sending outbound"
407+
f" message {msg_id} of type {msg_type}."
408+
)
391409
return await self._send(context.profile, message, self._inbound_message)
392410

393411
async def send_webhook(self, topic: str, payload: dict):

aries_cloudagent/core/tests/test_dispatcher.py

Lines changed: 143 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from marshmallow import EXCLUDE
88

9+
from ...cache.base import BaseCache
10+
from ...cache.in_memory import InMemoryCache
911
from ...config.injection_context import InjectionContext
1012
from ...core.event_bus import EventBus
1113
from ...core.in_memory import InMemoryProfile
@@ -413,12 +415,85 @@ async def test_create_send_outbound(self):
413415
profile,
414416
settings={"timing.enabled": True},
415417
)
418+
registry = profile.inject(ProtocolRegistry)
419+
registry.register_message_types(
420+
{
421+
pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage
422+
for pfx in DIDCommPrefix
423+
}
424+
)
416425
message = StubAgentMessage()
417426
responder = test_module.DispatcherResponder(context, message, None)
418-
outbound_message = await responder.create_outbound(message)
419-
with async_mock.patch.object(responder, "_send", async_mock.AsyncMock()):
427+
outbound_message = await responder.create_outbound(
428+
json.dumps(message.serialize())
429+
)
430+
with async_mock.patch.object(
431+
responder, "_send", async_mock.AsyncMock()
432+
), async_mock.patch.object(
433+
test_module.BaseResponder,
434+
"conn_rec_active_state_check",
435+
async_mock.AsyncMock(return_value=True),
436+
):
420437
await responder.send_outbound(outbound_message)
421438

439+
async def test_create_send_outbound_with_msg_attrs(self):
440+
profile = make_profile()
441+
context = RequestContext(
442+
profile,
443+
settings={"timing.enabled": True},
444+
)
445+
registry = profile.inject(ProtocolRegistry)
446+
registry.register_message_types(
447+
{
448+
pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage
449+
for pfx in DIDCommPrefix
450+
}
451+
)
452+
message = StubAgentMessage()
453+
responder = test_module.DispatcherResponder(context, message, None)
454+
outbound_message = await responder.create_outbound(message)
455+
with async_mock.patch.object(
456+
responder, "_send", async_mock.AsyncMock()
457+
), async_mock.patch.object(
458+
test_module.BaseResponder,
459+
"conn_rec_active_state_check",
460+
async_mock.AsyncMock(return_value=True),
461+
):
462+
await responder.send_outbound(
463+
message=outbound_message,
464+
message_type=message._message_type,
465+
message_id=message._id,
466+
)
467+
468+
async def test_create_send_outbound_with_msg_attrs_x(self):
469+
profile = make_profile()
470+
context = RequestContext(
471+
profile,
472+
settings={"timing.enabled": True},
473+
)
474+
registry = profile.inject(ProtocolRegistry)
475+
registry.register_message_types(
476+
{
477+
pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage
478+
for pfx in DIDCommPrefix
479+
}
480+
)
481+
message = StubAgentMessage()
482+
responder = test_module.DispatcherResponder(context, message, None)
483+
outbound_message = await responder.create_outbound(message)
484+
outbound_message.connection_id = "123"
485+
with async_mock.patch.object(
486+
test_module.BaseResponder,
487+
"conn_rec_active_state_check",
488+
async_mock.AsyncMock(return_value=False),
489+
):
490+
with self.assertRaises(RuntimeError):
491+
await responder.send_outbound(
492+
message=outbound_message,
493+
message_type=message._message_type,
494+
message_id=message._id,
495+
)
496+
422497
async def test_create_send_webhook(self):
423498
profile = make_profile()
424499
context = RequestContext(profile)
@@ -427,16 +502,81 @@ async def test_create_send_webhook(self):
427502
with pytest.deprecated_call():
428503
await responder.send_webhook("topic", {"pay": "load"})
429504

505+
async def test_conn_rec_active_state_check_a(self):
506+
profile = make_profile()
507+
profile.context.injector.bind_instance(BaseCache, InMemoryCache())
508+
context = RequestContext(profile)
509+
message = StubAgentMessage()
510+
responder = test_module.DispatcherResponder(context, message, None)
511+
with async_mock.patch.object(
512+
test_module.ConnRecord, "retrieve_by_id", async_mock.AsyncMock()
513+
) as mock_conn_ret_by_id:
514+
conn_rec = test_module.ConnRecord()
515+
conn_rec.state = test_module.ConnRecord.State.COMPLETED
516+
mock_conn_ret_by_id.return_value = conn_rec
517+
check_flag = await responder.conn_rec_active_state_check(
518+
profile,
519+
"conn-id",
520+
)
521+
assert check_flag
522+
check_flag = await responder.conn_rec_active_state_check(
523+
profile,
524+
"conn-id",
525+
)
526+
assert check_flag
527+
528+
async def test_conn_rec_active_state_check_b(self):
529+
profile = make_profile()
530+
profile.context.injector.bind_instance(BaseCache, InMemoryCache())
531+
profile.context.injector.bind_instance(
532+
EventBus, async_mock.MagicMock(notify=async_mock.AsyncMock())
533+
)
534+
context = RequestContext(profile)
535+
message = StubAgentMessage()
536+
responder = test_module.DispatcherResponder(context, message, None)
537+
with async_mock.patch.object(
538+
test_module.ConnRecord, "retrieve_by_id", async_mock.AsyncMock()
539+
) as mock_conn_ret_by_id:
540+
conn_rec_a = test_module.ConnRecord()
541+
conn_rec_a.state = test_module.ConnRecord.State.REQUEST
542+
conn_rec_b = test_module.ConnRecord()
543+
conn_rec_b.state = test_module.ConnRecord.State.COMPLETED
544+
mock_conn_ret_by_id.side_effect = [conn_rec_a, conn_rec_b]
545+
check_flag = await responder.conn_rec_active_state_check(
546+
profile,
547+
"conn-id",
548+
)
549+
assert check_flag
550+
430551
async def test_create_enc_outbound(self):
431552
profile = make_profile()
432553
context = RequestContext(profile)
433-
message = b"abc123xyz7890000"
554+
message = StubAgentMessage()
434555
responder = test_module.DispatcherResponder(context, message, None)
435556
with async_mock.patch.object(
436557
responder, "send_outbound", async_mock.AsyncMock()
437558
) as mock_send_outbound:
438559
await responder.send(message)
439560
assert mock_send_outbound.called_once()
561+
msg_json = json.dumps(StubAgentMessage().serialize())
562+
message = msg_json.encode("utf-8")
563+
with async_mock.patch.object(
564+
responder, "send_outbound", async_mock.AsyncMock()
565+
) as mock_send_outbound:
566+
await responder.send(message)
567+
568+
message = StubAgentMessage()
569+
with async_mock.patch.object(
570+
responder, "send_outbound", async_mock.AsyncMock()
571+
) as mock_send_outbound:
572+
await responder.send_reply(message)
573+
assert mock_send_outbound.called_once()
574+
575+
message = json.dumps(StubAgentMessage().serialize())
576+
with async_mock.patch.object(
577+
responder, "send_outbound", async_mock.AsyncMock()
578+
) as mock_send_outbound:
579+
await responder.send_reply(message)
440580

441581
async def test_expired_context_x(self):
442582
def _smaller_scope():

aries_cloudagent/messaging/responder.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,30 @@
44
The responder is provided to message handlers to enable them to send a new message
55
in response to the message being handled.
66
"""
7+
import asyncio
8+
import json
79

810
from abc import ABC, abstractmethod
9-
import json
10-
from typing import Sequence, Union
11+
from typing import Sequence, Union, Optional, Tuple
1112

13+
from ..cache.base import BaseCache
1214
from ..connections.models.connection_target import ConnectionTarget
15+
from ..connections.models.conn_record import ConnRecord
1316
from ..core.error import BaseError
17+
from ..core.profile import Profile
1418
from ..transport.outbound.message import OutboundMessage
1519

1620
from .base_message import BaseMessage
1721
from ..transport.outbound.status import OutboundSendStatus
1822

23+
SKIP_ACTIVE_CONN_CHECK_MSG_TYPES = [
24+
"didexchange/1.0/request",
25+
"didexchange/1.0/response",
26+
"connections/1.0/invitation",
27+
"connections/1.0/request",
28+
"connections/1.0/response",
29+
]
30+
1931

2032
class ResponderError(BaseError):
2133
"""Responder error."""
@@ -79,7 +91,18 @@ async def send(
7991
) -> OutboundSendStatus:
8092
"""Convert a message to an OutboundMessage and send it."""
8193
outbound = await self.create_outbound(message, **kwargs)
82-
return await self.send_outbound(outbound)
94+
if isinstance(message, BaseMessage):
95+
msg_type = message._message_type
96+
msg_id = message._id
97+
else:
98+
msg_dict = json.loads(message)
99+
msg_type = msg_dict.get("@type")
100+
msg_id = msg_dict.get("@id")
101+
return await self.send_outbound(
102+
message=outbound,
103+
message_type=msg_type,
104+
message_id=msg_id,
105+
)
83106

84107
async def send_reply(
85108
self,
@@ -109,10 +132,59 @@ async def send_reply(
109132
target=target,
110133
target_list=target_list,
111134
)
112-
return await self.send_outbound(outbound)
135+
if isinstance(message, BaseMessage):
136+
msg_type = message._message_type
137+
msg_id = message._id
138+
else:
139+
msg_dict = json.loads(message)
140+
msg_type = msg_dict.get("@type")
141+
msg_id = msg_dict.get("@id")
142+
return await self.send_outbound(
143+
message=outbound, message_type=msg_type, message_id=msg_id
144+
)
145+
146+
async def conn_rec_active_state_check(
147+
self, profile: Profile, connection_id: str, timeout: int = 7
148+
) -> bool:
149+
"""Check if the connection record is ready for sending outbound message."""
150+
151+
async def _wait_for_state() -> Tuple[bool, Optional[str]]:
152+
while True:
153+
async with profile.session() as session:
154+
conn_record = await ConnRecord.retrieve_by_id(
155+
session, connection_id
156+
)
157+
if conn_record.is_ready:
158+
# if ConnRecord.State.get(conn_record.state) in (
159+
# ConnRecord.State.COMPLETED,
160+
# ):
161+
return (True, conn_record.state)
162+
await asyncio.sleep(1)
163+
164+
try:
165+
cache_key = f"conn_rec_state::{connection_id}"
166+
connection_state = None
167+
cache = profile.inject_or(BaseCache)
168+
if cache:
169+
connection_state = await cache.get(cache_key)
170+
if connection_state and ConnRecord.State.get(connection_state) in (
171+
ConnRecord.State.COMPLETED,
172+
ConnRecord.State.RESPONSE,
173+
):
174+
return True
175+
check_flag, connection_state = await asyncio.wait_for(
176+
_wait_for_state(), timeout
177+
)
178+
if cache and connection_state:
179+
await cache.set(cache_key, connection_state)
180+
return check_flag
181+
except asyncio.TimeoutError:
182+
return False
113183

114184
@abstractmethod
115-
async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus:
185+
async def send_outbound(
186+
self, message: OutboundMessage, **kwargs
187+
) -> OutboundSendStatus:
116188
"""
117189
Send an outbound message.
118190
@@ -152,7 +224,9 @@ async def send_reply(
152224
self.messages.append((message, kwargs))
153225
return OutboundSendStatus.QUEUED_FOR_DELIVERY
154226

155-
async def send_outbound(self, message: OutboundMessage) -> OutboundSendStatus:
227+
async def send_outbound(
228+
self, message: OutboundMessage, **kwargs
229+
) -> OutboundSendStatus:
156230
"""Send an outbound message."""
157231
self.messages.append((message, None))
158232
return OutboundSendStatus.QUEUED_FOR_DELIVERY

aries_cloudagent/protocols/connections/v1_0/routes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from ....admin.request_context import AdminRequestContext
1717
from ....connections.models.conn_record import ConnRecord, ConnRecordSchema
18+
from ....cache.base import BaseCache
1819
from ....messaging.models.base import BaseModelError
1920
from ....messaging.models.openapi import OpenAPISchema
2021
from ....messaging.valid import (
@@ -739,6 +740,9 @@ async def connections_remove(request: web.BaseRequest):
739740
async with profile.session() as session:
740741
connection = await ConnRecord.retrieve_by_id(session, connection_id)
741742
await connection.delete_record(session)
743+
cache = session.inject_or(BaseCache)
744+
if cache:
745+
await cache.clear(f"conn_rec_state::{connection_id}")
742746
except StorageNotFoundError as err:
743747
raise web.HTTPNotFound(reason=err.roll_up) from err
744748
except StorageError as err:

0 commit comments

Comments
 (0)