Skip to content

Commit

Permalink
Merge pull request #2748 from dbluhm/fix/didexchange-1.1
Browse files Browse the repository at this point in the history
fix(credo-interop): various didexchange and did:peer related fixes
  • Loading branch information
swcurran committed Mar 26, 2024
2 parents f1bb749 + 19547fc commit 41d8024
Show file tree
Hide file tree
Showing 44 changed files with 1,066 additions and 1,265 deletions.
7 changes: 5 additions & 2 deletions aries_cloudagent/askar/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,12 @@ def store(self) -> Store:
@property
def is_transaction(self) -> bool:
"""Check if the session supports commit and rollback operations."""
if self._handle:
if self._handle: # opened
return self._handle.is_transaction
return self._opener.is_transaction
if self._opener: # opening
return self._opener.is_transaction

raise ProfileError("Session not open")

async def _setup(self):
"""Create the session or transaction connection, if needed."""
Expand Down
4 changes: 0 additions & 4 deletions aries_cloudagent/config/default_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from ..core.protocol_registry import ProtocolRegistry
from ..protocols.actionmenu.v1_0.base_service import BaseMenuService
from ..protocols.actionmenu.v1_0.driver_service import DriverMenuService
from ..protocols.didcomm_prefix import DIDCommPrefix
from ..protocols.introduction.v0_1.base_service import BaseIntroductionService
from ..protocols.introduction.v0_1.demo_service import DemoIntroductionService
from ..resolver.did_resolver import DIDResolver
Expand Down Expand Up @@ -66,9 +65,6 @@ async def build_context(self) -> InjectionContext:
await self.bind_providers(context)
await self.load_plugins(context)

# Set DIDComm prefix
DIDCommPrefix.set(context.settings)

return context

async def bind_providers(self, context: InjectionContext):
Expand Down
5 changes: 4 additions & 1 deletion aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ async def create_did_peer_4(
async with self._profile.session() as session:
wallet = session.inject(BaseWallet)
key = await wallet.create_key(ED25519)
key_spec = KeySpec_DP4(multikey=self._key_info_to_multikey(key))
key_spec = KeySpec_DP4(
multikey=self._key_info_to_multikey(key),
relationships=["authentication", "keyAgreement"],
)
input_doc = input_doc_from_keys_and_services(
keys=[key_spec], services=services
)
Expand Down
43 changes: 11 additions & 32 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
)
from ...protocols.connections.v1_0.messages.connection_request import ConnectionRequest
from ...protocols.didcomm_prefix import DIDCommPrefix
from ...protocols.didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO
from ...protocols.didexchange.v1_0.message_types import (
ARIES_PROTOCOL as DIDEX_1_1,
DIDEX_1_0,
)
from ...protocols.didexchange.v1_0.messages.request import DIDXRequest
from ...protocols.out_of_band.v1_0.messages.invitation import (
InvitationMessage as OOBInvitation,
Expand All @@ -43,27 +46,7 @@ class Meta:

schema_class = "MaybeStoredConnRecordSchema"

class Protocol(Enum):
"""Supported Protocols for Connection."""

RFC_0160 = CONN_PROTO
RFC_0023 = DIDX_PROTO

@classmethod
def get(cls, label: Union[str, "ConnRecord.Protocol"]):
"""Get aries protocol enum for label."""
if isinstance(label, str):
for proto in ConnRecord.Protocol:
if label in proto.value:
return proto
elif isinstance(label, ConnRecord.Protocol):
return label
return None

@property
def aries_protocol(self):
"""Return used connection protocol."""
return self.value
SUPPORTED_PROTOCOLS = (CONN_PROTO, DIDEX_1_0, DIDEX_1_1)

class Role(Enum):
"""RFC 160 (inviter, invitee) = RFC 23 (responder, requester)."""
Expand Down Expand Up @@ -211,7 +194,7 @@ def __init__(
invitation_mode: Optional[str] = None,
alias: Optional[str] = None,
their_public_did: Optional[str] = None,
connection_protocol: Union[str, "ConnRecord.Protocol", None] = None,
connection_protocol: Optional[str] = None,
# from state: formalism for base_record.from_storage()
rfc23_state: Optional[str] = None,
# for backward compat with old records
Expand Down Expand Up @@ -244,13 +227,9 @@ def __init__(
self.alias = alias
self.their_public_did = their_public_did
self.connection_protocol = (
ConnRecord.Protocol.get(connection_protocol).aries_protocol
if isinstance(connection_protocol, str)
else (
None
if connection_protocol is None
else connection_protocol.aries_protocol
)
connection_protocol
if connection_protocol in self.SUPPORTED_PROTOCOLS
else None
)

@property
Expand Down Expand Up @@ -684,10 +663,10 @@ class Meta:
)
connection_protocol = fields.Str(
required=False,
validate=validate.OneOf([proto.value for proto in ConnRecord.Protocol]),
validate=validate.OneOf(ConnRecord.SUPPORTED_PROTOCOLS),
metadata={
"description": "Connection protocol used",
"example": ConnRecord.Protocol.RFC_0160.aries_protocol,
"example": "connections/1.0",
},
)
rfc23_state = fields.Str(
Expand Down
17 changes: 0 additions & 17 deletions aries_cloudagent/connections/models/tests/test_conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,6 @@ def setUp(self):
assert self.test_conn_record.state == ConnRecord.State.COMPLETED.rfc160
assert self.test_conn_record.rfc23_state == ConnRecord.State.COMPLETED.rfc23

def test_get_protocol(self):
assert ConnRecord.Protocol.get("test") is None
assert (
ConnRecord.Protocol.get("didexchange/1.0") is ConnRecord.Protocol.RFC_0023
)
assert (
ConnRecord.Protocol.get(ConnRecord.Protocol.RFC_0023)
is ConnRecord.Protocol.RFC_0023
)
assert (
ConnRecord.Protocol.get("connections/1.0") is ConnRecord.Protocol.RFC_0160
)
assert (
ConnRecord.Protocol.get(ConnRecord.Protocol.RFC_0160)
is ConnRecord.Protocol.RFC_0160
)

async def test_get_enums(self):
assert ConnRecord.Role.get("Larry") is None
assert ConnRecord.State.get("a suffusion of yellow") is None
Expand Down
63 changes: 8 additions & 55 deletions aries_cloudagent/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import asyncio
import logging
import os
from typing import Callable, Coroutine, Optional, Tuple, Union
from typing import Callable, Coroutine, Union
import warnings
import weakref

Expand All @@ -27,12 +27,12 @@
from ..transport.inbound.message import InboundMessage
from ..transport.outbound.message import OutboundMessage
from ..transport.outbound.status import OutboundSendStatus
from ..utils.classloader import DeferLoad
from ..utils.stats import Collector
from ..utils.task_queue import CompletedTask, PendingTask, TaskQueue
from ..utils.tracing import get_timer, trace_event
from .error import ProtocolMinorVersionNotSupported
from .protocol_registry import ProtocolRegistry
from .util import get_version_from_message_type, validate_get_response_version


class ProblemReportParseError(MessageParseError):
Expand Down Expand Up @@ -139,9 +139,7 @@ async def handle_message(
version_warning = None
message = None
try:
(message, warning) = await self.make_message(
profile, inbound_message.payload
)
message = await self.make_message(profile, inbound_message.payload)
except ProblemReportParseError:
pass # avoid problem report recursion
except MessageParseError as e:
Expand All @@ -156,47 +154,6 @@ async def handle_message(
)
if inbound_message.receipt.thread_id:
error_result.assign_thread_id(inbound_message.receipt.thread_id)
# if warning:
# warning_message_type = inbound_message.payload.get("@type")
# if warning == WARNING_DEGRADED_FEATURES:
# self.logger.error(
# f"Sending {WARNING_DEGRADED_FEATURES} problem report, "
# "message type received with a minor version at or higher"
# " than protocol minimum supported and current minor version "
# f"for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version at or "
# "higher than protocol minimum supported and current"
# f" minor version for message_type {warning_message_type}"
# ),
# "code": WARNING_DEGRADED_FEATURES,
# }
# )
# elif warning == WARNING_VERSION_MISMATCH:
# self.logger.error(
# f"Sending {WARNING_VERSION_MISMATCH} problem report, message "
# "type received with a minor version higher than current minor "
# f"version for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version higher"
# " than current minor version for message_type"
# f" {warning_message_type}"
# ),
# "code": WARNING_VERSION_MISMATCH,
# }
# )
# elif warning == WARNING_VERSION_NOT_SUPPORTED:
# raise MessageParseError(
# f"Message type version not supported for {warning_message_type}"
# )
# if version_warning and inbound_message.receipt.thread_id:
# version_warning.assign_thread_id(inbound_message.receipt.thread_id)

trace_event(
self.profile.settings,
Expand Down Expand Up @@ -259,9 +216,7 @@ async def handle_message(
perf_counter=r_time,
)

async def make_message(
self, profile: Profile, parsed_msg: dict
) -> Tuple[BaseMessage, Optional[str]]:
async def make_message(self, profile: Profile, parsed_msg: dict) -> BaseMessage:
"""Deserialize a message dict into the appropriate message instance.
Given a dict describing a message, this method
Expand Down Expand Up @@ -302,11 +257,11 @@ async def make_message(
"a future release. Use https://didcomm.org/ instead.",
)

message_type_rec_version = get_version_from_message_type(message_type)

registry: ProtocolRegistry = self.profile.inject(ProtocolRegistry)
try:
message_cls = registry.resolve_message_class(message_type)
if isinstance(message_cls, DeferLoad):
message_cls = message_cls.resolved
except ProtocolMinorVersionNotSupported as e:
raise MessageParseError(f"Problem parsing message type. {e}")

Expand All @@ -319,10 +274,8 @@ async def make_message(
if "/problem-report" in message_type:
raise ProblemReportParseError("Error parsing problem report message")
raise MessageParseError(f"Error deserializing message: {e}") from e
_, warning = await validate_get_response_version(
profile, message_type_rec_version, message_cls
)
return (instance, warning)

return instance

async def complete(self, timeout: float = 0.1):
"""Wait for pending tasks to complete."""
Expand Down
19 changes: 12 additions & 7 deletions aries_cloudagent/core/oob_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,20 @@ async def clean_finished_oob_record(self, profile: Profile, message: AgentMessag
{"role": OobRecord.ROLE_SENDER},
)

# If the oob record is not multi use and it doesn't contain any attachments
# We can now safely remove the oob record
if not oob_record.multi_use and not oob_record.invitation.requests_attach:
oob_record.state = OobRecord.STATE_DONE
await oob_record.emit_event(session)
await oob_record.delete_record(session)
except Exception:
# If the oob record is not multi use and it doesn't contain any
# attachments, we can now safely remove the oob record
if (
not oob_record.multi_use
and not oob_record.invitation.requests_attach
):
oob_record.state = OobRecord.STATE_DONE
await oob_record.emit_event(session)
await oob_record.delete_record(session)
except StorageNotFoundError:
# It is fine if no oob record is found, Only retrieved for cleanup
pass
except Exception:
LOGGER.warning("Error cleaning up oob record", exc_info=True)

async def find_oob_target_for_outbound_message(
self, profile: Profile, outbound_message: OutboundMessage
Expand Down
4 changes: 1 addition & 3 deletions aries_cloudagent/core/plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ async def load_protocol_version(
mod.MESSAGE_TYPES, version_definition=version_definition
)
if hasattr(mod, "CONTROLLERS"):
protocol_registry.register_controllers(
mod.CONTROLLERS, version_definition=version_definition
)
protocol_registry.register_controllers(mod.CONTROLLERS)
goal_code_registry.register_controllers(mod.CONTROLLERS)

async def load_protocols(self, context: InjectionContext, plugin: ModuleType):
Expand Down
Loading

0 comments on commit 41d8024

Please sign in to comment.