diff --git a/aries_cloudagent/askar/profile.py b/aries_cloudagent/askar/profile.py index e9b84db4e0..27cec91b7c 100644 --- a/aries_cloudagent/askar/profile.py +++ b/aries_cloudagent/askar/profile.py @@ -224,9 +224,12 @@ def store(self) -> Store: @property def is_transaction(self) -> bool: """Check if the session supports commit and rollback operations.""" - if self._handle: + if self._handle: # opened return self._handle.is_transaction - return self._opener.is_transaction + if self._opener: # opening + return self._opener.is_transaction + + raise ProfileError("Session not open") async def _setup(self): """Create the session or transaction connection, if needed.""" diff --git a/aries_cloudagent/config/default_context.py b/aries_cloudagent/config/default_context.py index 2f8c0181fe..a14de29720 100644 --- a/aries_cloudagent/config/default_context.py +++ b/aries_cloudagent/config/default_context.py @@ -10,7 +10,6 @@ from ..core.protocol_registry import ProtocolRegistry from ..protocols.actionmenu.v1_0.base_service import BaseMenuService from ..protocols.actionmenu.v1_0.driver_service import DriverMenuService -from ..protocols.didcomm_prefix import DIDCommPrefix from ..protocols.introduction.v0_1.base_service import BaseIntroductionService from ..protocols.introduction.v0_1.demo_service import DemoIntroductionService from ..resolver.did_resolver import DIDResolver @@ -66,9 +65,6 @@ async def build_context(self) -> InjectionContext: await self.bind_providers(context) await self.load_plugins(context) - # Set DIDComm prefix - DIDCommPrefix.set(context.settings) - return context async def bind_providers(self, context: InjectionContext): diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index 817288d7c8..f771752e46 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -160,7 +160,10 @@ async def create_did_peer_4( async with self._profile.session() as session: wallet = session.inject(BaseWallet) key = await wallet.create_key(ED25519) - key_spec = KeySpec_DP4(multikey=self._key_info_to_multikey(key)) + key_spec = KeySpec_DP4( + multikey=self._key_info_to_multikey(key), + relationships=["authentication", "keyAgreement"], + ) input_doc = input_doc_from_keys_and_services( keys=[key_spec], services=services ) diff --git a/aries_cloudagent/connections/models/conn_record.py b/aries_cloudagent/connections/models/conn_record.py index c45089c8ce..da556ed4ab 100644 --- a/aries_cloudagent/connections/models/conn_record.py +++ b/aries_cloudagent/connections/models/conn_record.py @@ -25,7 +25,10 @@ ) from ...protocols.connections.v1_0.messages.connection_request import ConnectionRequest from ...protocols.didcomm_prefix import DIDCommPrefix -from ...protocols.didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO +from ...protocols.didexchange.v1_0.message_types import ( + ARIES_PROTOCOL as DIDEX_1_1, + DIDEX_1_0, +) from ...protocols.didexchange.v1_0.messages.request import DIDXRequest from ...protocols.out_of_band.v1_0.messages.invitation import ( InvitationMessage as OOBInvitation, @@ -43,27 +46,7 @@ class Meta: schema_class = "MaybeStoredConnRecordSchema" - class Protocol(Enum): - """Supported Protocols for Connection.""" - - RFC_0160 = CONN_PROTO - RFC_0023 = DIDX_PROTO - - @classmethod - def get(cls, label: Union[str, "ConnRecord.Protocol"]): - """Get aries protocol enum for label.""" - if isinstance(label, str): - for proto in ConnRecord.Protocol: - if label in proto.value: - return proto - elif isinstance(label, ConnRecord.Protocol): - return label - return None - - @property - def aries_protocol(self): - """Return used connection protocol.""" - return self.value + SUPPORTED_PROTOCOLS = (CONN_PROTO, DIDEX_1_0, DIDEX_1_1) class Role(Enum): """RFC 160 (inviter, invitee) = RFC 23 (responder, requester).""" @@ -211,7 +194,7 @@ def __init__( invitation_mode: Optional[str] = None, alias: Optional[str] = None, their_public_did: Optional[str] = None, - connection_protocol: Union[str, "ConnRecord.Protocol", None] = None, + connection_protocol: Optional[str] = None, # from state: formalism for base_record.from_storage() rfc23_state: Optional[str] = None, # for backward compat with old records @@ -244,13 +227,9 @@ def __init__( self.alias = alias self.their_public_did = their_public_did self.connection_protocol = ( - ConnRecord.Protocol.get(connection_protocol).aries_protocol - if isinstance(connection_protocol, str) - else ( - None - if connection_protocol is None - else connection_protocol.aries_protocol - ) + connection_protocol + if connection_protocol in self.SUPPORTED_PROTOCOLS + else None ) @property @@ -684,10 +663,10 @@ class Meta: ) connection_protocol = fields.Str( required=False, - validate=validate.OneOf([proto.value for proto in ConnRecord.Protocol]), + validate=validate.OneOf(ConnRecord.SUPPORTED_PROTOCOLS), metadata={ "description": "Connection protocol used", - "example": ConnRecord.Protocol.RFC_0160.aries_protocol, + "example": "connections/1.0", }, ) rfc23_state = fields.Str( diff --git a/aries_cloudagent/connections/models/tests/test_conn_record.py b/aries_cloudagent/connections/models/tests/test_conn_record.py index 13bb06df58..125330c70a 100644 --- a/aries_cloudagent/connections/models/tests/test_conn_record.py +++ b/aries_cloudagent/connections/models/tests/test_conn_record.py @@ -35,23 +35,6 @@ def setUp(self): assert self.test_conn_record.state == ConnRecord.State.COMPLETED.rfc160 assert self.test_conn_record.rfc23_state == ConnRecord.State.COMPLETED.rfc23 - def test_get_protocol(self): - assert ConnRecord.Protocol.get("test") is None - assert ( - ConnRecord.Protocol.get("didexchange/1.0") is ConnRecord.Protocol.RFC_0023 - ) - assert ( - ConnRecord.Protocol.get(ConnRecord.Protocol.RFC_0023) - is ConnRecord.Protocol.RFC_0023 - ) - assert ( - ConnRecord.Protocol.get("connections/1.0") is ConnRecord.Protocol.RFC_0160 - ) - assert ( - ConnRecord.Protocol.get(ConnRecord.Protocol.RFC_0160) - is ConnRecord.Protocol.RFC_0160 - ) - async def test_get_enums(self): assert ConnRecord.Role.get("Larry") is None assert ConnRecord.State.get("a suffusion of yellow") is None diff --git a/aries_cloudagent/core/dispatcher.py b/aries_cloudagent/core/dispatcher.py index c0b6e2ad3e..4887af4de5 100644 --- a/aries_cloudagent/core/dispatcher.py +++ b/aries_cloudagent/core/dispatcher.py @@ -7,7 +7,7 @@ import asyncio import logging import os -from typing import Callable, Coroutine, Optional, Tuple, Union +from typing import Callable, Coroutine, Union import warnings import weakref @@ -27,12 +27,12 @@ from ..transport.inbound.message import InboundMessage from ..transport.outbound.message import OutboundMessage from ..transport.outbound.status import OutboundSendStatus +from ..utils.classloader import DeferLoad from ..utils.stats import Collector from ..utils.task_queue import CompletedTask, PendingTask, TaskQueue from ..utils.tracing import get_timer, trace_event from .error import ProtocolMinorVersionNotSupported from .protocol_registry import ProtocolRegistry -from .util import get_version_from_message_type, validate_get_response_version class ProblemReportParseError(MessageParseError): @@ -139,9 +139,7 @@ async def handle_message( version_warning = None message = None try: - (message, warning) = await self.make_message( - profile, inbound_message.payload - ) + message = await self.make_message(profile, inbound_message.payload) except ProblemReportParseError: pass # avoid problem report recursion except MessageParseError as e: @@ -156,47 +154,6 @@ async def handle_message( ) if inbound_message.receipt.thread_id: error_result.assign_thread_id(inbound_message.receipt.thread_id) - # if warning: - # warning_message_type = inbound_message.payload.get("@type") - # if warning == WARNING_DEGRADED_FEATURES: - # self.logger.error( - # f"Sending {WARNING_DEGRADED_FEATURES} problem report, " - # "message type received with a minor version at or higher" - # " than protocol minimum supported and current minor version " - # f"for message_type {warning_message_type}" - # ) - # version_warning = ProblemReport( - # description={ - # "en": ( - # "message type received with a minor version at or " - # "higher than protocol minimum supported and current" - # f" minor version for message_type {warning_message_type}" - # ), - # "code": WARNING_DEGRADED_FEATURES, - # } - # ) - # elif warning == WARNING_VERSION_MISMATCH: - # self.logger.error( - # f"Sending {WARNING_VERSION_MISMATCH} problem report, message " - # "type received with a minor version higher than current minor " - # f"version for message_type {warning_message_type}" - # ) - # version_warning = ProblemReport( - # description={ - # "en": ( - # "message type received with a minor version higher" - # " than current minor version for message_type" - # f" {warning_message_type}" - # ), - # "code": WARNING_VERSION_MISMATCH, - # } - # ) - # elif warning == WARNING_VERSION_NOT_SUPPORTED: - # raise MessageParseError( - # f"Message type version not supported for {warning_message_type}" - # ) - # if version_warning and inbound_message.receipt.thread_id: - # version_warning.assign_thread_id(inbound_message.receipt.thread_id) trace_event( self.profile.settings, @@ -259,9 +216,7 @@ async def handle_message( perf_counter=r_time, ) - async def make_message( - self, profile: Profile, parsed_msg: dict - ) -> Tuple[BaseMessage, Optional[str]]: + async def make_message(self, profile: Profile, parsed_msg: dict) -> BaseMessage: """Deserialize a message dict into the appropriate message instance. Given a dict describing a message, this method @@ -302,11 +257,11 @@ async def make_message( "a future release. Use https://didcomm.org/ instead.", ) - message_type_rec_version = get_version_from_message_type(message_type) - registry: ProtocolRegistry = self.profile.inject(ProtocolRegistry) try: message_cls = registry.resolve_message_class(message_type) + if isinstance(message_cls, DeferLoad): + message_cls = message_cls.resolved except ProtocolMinorVersionNotSupported as e: raise MessageParseError(f"Problem parsing message type. {e}") @@ -319,10 +274,8 @@ async def make_message( if "/problem-report" in message_type: raise ProblemReportParseError("Error parsing problem report message") raise MessageParseError(f"Error deserializing message: {e}") from e - _, warning = await validate_get_response_version( - profile, message_type_rec_version, message_cls - ) - return (instance, warning) + + return instance async def complete(self, timeout: float = 0.1): """Wait for pending tasks to complete.""" diff --git a/aries_cloudagent/core/oob_processor.py b/aries_cloudagent/core/oob_processor.py index d1305aca3a..79ab103459 100644 --- a/aries_cloudagent/core/oob_processor.py +++ b/aries_cloudagent/core/oob_processor.py @@ -57,15 +57,20 @@ async def clean_finished_oob_record(self, profile: Profile, message: AgentMessag {"role": OobRecord.ROLE_SENDER}, ) - # If the oob record is not multi use and it doesn't contain any attachments - # We can now safely remove the oob record - if not oob_record.multi_use and not oob_record.invitation.requests_attach: - oob_record.state = OobRecord.STATE_DONE - await oob_record.emit_event(session) - await oob_record.delete_record(session) - except Exception: + # If the oob record is not multi use and it doesn't contain any + # attachments, we can now safely remove the oob record + if ( + not oob_record.multi_use + and not oob_record.invitation.requests_attach + ): + oob_record.state = OobRecord.STATE_DONE + await oob_record.emit_event(session) + await oob_record.delete_record(session) + except StorageNotFoundError: # It is fine if no oob record is found, Only retrieved for cleanup pass + except Exception: + LOGGER.warning("Error cleaning up oob record", exc_info=True) async def find_oob_target_for_outbound_message( self, profile: Profile, outbound_message: OutboundMessage diff --git a/aries_cloudagent/core/plugin_registry.py b/aries_cloudagent/core/plugin_registry.py index f286e2e404..db25f3c6ce 100644 --- a/aries_cloudagent/core/plugin_registry.py +++ b/aries_cloudagent/core/plugin_registry.py @@ -226,9 +226,7 @@ async def load_protocol_version( mod.MESSAGE_TYPES, version_definition=version_definition ) if hasattr(mod, "CONTROLLERS"): - protocol_registry.register_controllers( - mod.CONTROLLERS, version_definition=version_definition - ) + protocol_registry.register_controllers(mod.CONTROLLERS) goal_code_registry.register_controllers(mod.CONTROLLERS) async def load_protocols(self, context: InjectionContext, plugin: ModuleType): diff --git a/aries_cloudagent/core/protocol_registry.py b/aries_cloudagent/core/protocol_registry.py index c322fad2ad..b64b789efd 100644 --- a/aries_cloudagent/core/protocol_registry.py +++ b/aries_cloudagent/core/protocol_registry.py @@ -1,47 +1,89 @@ """Handle registration and publication of supported protocols.""" +from dataclasses import dataclass import logging -import re -from typing import Mapping, Sequence +from typing import Any, Dict, Mapping, Optional, Sequence, Union from ..config.injection_context import InjectionContext -from ..utils.classloader import ClassLoader +from ..utils.classloader import ClassLoader, DeferLoad +from ..messaging.message_type import MessageType, MessageVersion, ProtocolIdentifier from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError LOGGER = logging.getLogger(__name__) +@dataclass +class VersionDefinition: + """Version definition.""" + + min: MessageVersion + current: MessageVersion + + @classmethod + def from_dict(cls, data: dict) -> "VersionDefinition": + """Create a version definition from a dict.""" + return cls( + min=MessageVersion(data["major_version"], data["minimum_minor_version"]), + current=MessageVersion( + data["major_version"], data["current_minor_version"] + ), + ) + + +@dataclass +class ProtocolDefinition: + """Protocol metadata used to register and resolve message types.""" + + ident: ProtocolIdentifier + min: MessageVersion + current: MessageVersion + controller: Optional[str] = None + + @property + def minor_versions_supported(self) -> bool: + """Accessor for whether minor versions are supported.""" + return bool(self.current.minor >= 1 and self.current.minor >= self.min.minor) + + def __post_init__(self): + """Post-init hook.""" + if self.min.major != self.current.major: + raise ProtocolDefinitionValidationError( + f"Major version mismatch: {self.min.major} != {self.current.major}" + ) + if self.min.minor > self.current.minor: + raise ProtocolDefinitionValidationError( + f"Minimum minor version greater than current minor version: " + f"{self.min.minor} > {self.current.minor}" + ) + + class ProtocolRegistry: """Protocol registry for indexing message families.""" def __init__(self): """Initialize a `ProtocolRegistry` instance.""" + + self._definitions: Dict[str, ProtocolDefinition] = {} + self._type_to_message_cls: Dict[str, Union[DeferLoad, type]] = {} + + # Mapping[protocol identifier, controller module path] self._controllers = {} - self._typemap = {} - self._versionmap = {} @property def protocols(self) -> Sequence[str]: """Accessor for a list of all message protocols.""" - prots = set() - for message_type in self._typemap.keys(): - pos = message_type.rfind("/") - if pos > 0: - family = message_type[:pos] - prots.add(family) - return prots + return [ + str(definition.ident.with_version((definition.min.major, minor))) + for definition in self._definitions.values() + for minor in range(definition.min.minor, definition.current.minor + 1) + ] @property def message_types(self) -> Sequence[str]: """Accessor for a list of all message types.""" - return tuple(self._typemap.keys()) - - @property - def controllers(self) -> Mapping[str, str]: - """Accessor for a list of all protocol controller functions.""" - return self._controllers.copy() + return tuple(self._type_to_message_cls.keys()) def protocols_matching_query(self, query: str) -> Sequence[str]: """Return a list of message protocols matching a query string.""" @@ -58,129 +100,65 @@ def protocols_matching_query(self, query: str) -> Sequence[str]: result = (query,) return result or () - def parse_type_string(self, message_type): - """Parse message type string and return dict with info.""" - tokens = message_type.split("/") - protocol_name = tokens[-3] - version_string = tokens[-2] - message_name = tokens[-1] - - version_string_tokens = version_string.split(".") - assert len(version_string_tokens) == 2 - - return { - "protocol_name": protocol_name, - "message_name": message_name, - "major_version": int(version_string_tokens[0]), - "minor_version": int(version_string_tokens[1]), - } - - def create_msg_types_for_minor_version(self, typesets, version_definition): - """Return mapping of message type to module path for minor versions. + def register_message_types( + self, + typeset: Mapping[str, Union[str, type]], + version_definition: Optional[Union[dict[str, Any], VersionDefinition]] = None, + ): + """Add new supported message types. Args: typesets: Mappings of message types to register version_definition: Optional version definition dict - Returns: - Typesets mapping - """ - updated_typeset = {} - curr_minor_version = version_definition["current_minor_version"] - min_minor_version = version_definition["minimum_minor_version"] - major_version = version_definition["major_version"] - if curr_minor_version >= min_minor_version: - for version_index in range(min_minor_version, curr_minor_version + 1): - to_check = f"{str(major_version)}.{str(version_index)}" - updated_typeset.update( - self._get_updated_typeset_dict(typesets, to_check, updated_typeset) - ) - else: - raise ProtocolDefinitionValidationError( - "min_minor_version is greater than curr_minor_version for the" - f" following typeset: {str(typesets)}" - ) - return (updated_typeset,) + if version_definition is not None and isinstance(version_definition, dict): + version_definition = VersionDefinition.from_dict(version_definition) + + definitions_to_add = {} + type_to_message_cls_to_add = {} + + for message_type, message_cls in typeset.items(): + parsed = MessageType.from_str(message_type) + protocol = ProtocolIdentifier.from_message_type(parsed) + if protocol.stem in definitions_to_add: + definition = definitions_to_add[protocol.stem] + elif protocol.stem in self._definitions: + definition = self._definitions[protocol.stem] + else: + if version_definition: + definition = ProtocolDefinition( + ident=protocol, + min=version_definition.min, + current=version_definition.current, + ) + else: + definition = ProtocolDefinition( + ident=protocol, + min=protocol.version, + current=protocol.version, + ) - def _get_updated_typeset_dict(self, typesets, to_check, updated_typeset) -> dict: - for typeset in typesets: - for msg_type_string, module_path in typeset.items(): - updated_msg_type_string = re.sub( - r"(\d+\.)?(\*|\d+)", to_check, msg_type_string - ) - updated_typeset[updated_msg_type_string] = module_path - return updated_typeset - - def _message_type_check_for_minor_verssion(self, version_definition) -> bool: - if not version_definition: - return False - curr_minor_version = version_definition["current_minor_version"] - min_minor_version = version_definition["minimum_minor_version"] - return bool(curr_minor_version >= 1 and curr_minor_version >= min_minor_version) - - def _create_and_register_updated_typesets(self, typesets, version_definition): - updated_typesets = self.create_msg_types_for_minor_version( - typesets, version_definition - ) - update_flag = False - for typeset in updated_typesets: - if typeset: - self._typemap.update(typeset) - update_flag = True - if update_flag: - return updated_typesets - else: - return None - - def _update_version_map(self, message_type_string, module_path, version_definition): - parsed_type_string = self.parse_type_string(message_type_string) - - if version_definition["major_version"] not in self._versionmap: - self._versionmap[version_definition["major_version"]] = [] - - self._versionmap[version_definition["major_version"]].append( - { - "parsed_type_string": parsed_type_string, - "version_definition": version_definition, - "message_module": module_path, - } - ) + definitions_to_add[protocol.stem] = definition - def register_message_types(self, *typesets, version_definition=None): - """Add new supported message types. + if isinstance(message_cls, str): + message_cls = DeferLoad(message_cls) - Args: - typesets: Mappings of message types to register - version_definition: Optional version definition dict - - """ + type_to_message_cls_to_add[message_type] = message_cls - # Maintain support for versionless protocol modules - updated_typesets = None - minor_versions_supported = self._message_type_check_for_minor_verssion( - version_definition - ) - if not minor_versions_supported: - for typeset in typesets: - self._typemap.update(typeset) - - # Track versioned modules for version routing - if version_definition: - # create updated typesets for minor versions and register them - if minor_versions_supported: - updated_typesets = self._create_and_register_updated_typesets( - typesets, version_definition - ) - if updated_typesets: - typesets = updated_typesets - for typeset in typesets: - for message_type_string, module_path in typeset.items(): - self._update_version_map( - message_type_string, module_path, version_definition + if definition.minor_versions_supported: + for minor_version in range( + definition.min.minor, definition.current.minor + 1 + ): + updated_type = parsed.with_version( + (parsed.version.major, minor_version) ) + type_to_message_cls_to_add[str(updated_type)] = message_cls + + self._type_to_message_cls.update(type_to_message_cls_to_add) + self._definitions.update(definitions_to_add) - def register_controllers(self, *controller_sets, version_definition=None): + def register_controllers(self, *controller_sets): """Add new controllers. Args: @@ -190,7 +168,9 @@ def register_controllers(self, *controller_sets, version_definition=None): for controlset in controller_sets: self._controllers.update(controlset) - def resolve_message_class(self, message_type: str) -> type: + def resolve_message_class( + self, message_type: str + ) -> Optional[Union[DeferLoad, type]]: """Resolve a message_type to a message class. Given a message type identifier, this method @@ -203,46 +183,25 @@ def resolve_message_class(self, message_type: str) -> type: The resolved message class """ + if (message_cls := self._type_to_message_cls.get(message_type)) is not None: + return message_cls + + parsed = MessageType.from_str(message_type) + protocol = ProtocolIdentifier.from_message_type(parsed) + if definition := self._definitions.get(protocol.stem): + if parsed.version.minor < definition.min.minor: + raise ProtocolMinorVersionNotSupported( + f"Minimum supported minor version is {definition.min.minor}." + f" Received {parsed.version.minor}." + ) - # Try and retrieve from direct mapping - msg_cls = self._typemap.get(message_type) - if isinstance(msg_cls, str): - return ClassLoader.load_class(msg_cls) - - # Support registered modules (not path as string) - elif msg_cls: - return msg_cls - - # Try and route via min/maj version matching - if not msg_cls: - parsed_type_string = self.parse_type_string(message_type) - major_version = parsed_type_string["major_version"] - - version_supported_protos = self._versionmap.get(major_version) - if not version_supported_protos: - return None - - for proto in version_supported_protos: - if ( - proto["parsed_type_string"]["protocol_name"] - == parsed_type_string["protocol_name"] - and proto["parsed_type_string"]["message_name"] - == parsed_type_string["message_name"] - ): - if ( - parsed_type_string["minor_version"] - < proto["version_definition"]["minimum_minor_version"] - ): - raise ProtocolMinorVersionNotSupported( - "Minimum supported minor version is " - + f"{proto['version_definition']['minimum_minor_version']}." - + f" Received {parsed_type_string['minor_version']}." - ) - - if isinstance(proto["message_module"], str): - return ClassLoader.load_class(msg_cls) - elif proto["message_module"]: - return proto["message_module"] + # This code will only be reached if the received minor version is greater + # than our current supported version. All directly supported minor + # versions would be returned previously. + message_type = str(parsed.with_version(definition.current)) + + if (message_cls := self._type_to_message_cls.get(message_type)) is not None: + return message_cls return None @@ -252,7 +211,7 @@ async def prepare_disclosed( """Call controllers and return publicly supported message families and roles.""" published = [] for protocol in protocols: - result = {"pid": protocol} + result: Dict[str, Any] = {"pid": protocol} if protocol in self._controllers: ctl_cls = self._controllers[protocol] if isinstance(ctl_cls, str): diff --git a/aries_cloudagent/core/tests/test_dispatcher.py b/aries_cloudagent/core/tests/test_dispatcher.py index 0644c6d3fa..5d7ef0ac5f 100644 --- a/aries_cloudagent/core/tests/test_dispatcher.py +++ b/aries_cloudagent/core/tests/test_dispatcher.py @@ -60,7 +60,7 @@ class StubAgentMessage(AgentMessage): class Meta: handler_class = "StubAgentMessageHandler" schema_class = "StubAgentMessageSchema" - message_type = "proto-name/1.1/message-type" + message_type = "doc/proto-name/1.1/message-type" class StubAgentMessageSchema(AgentMessageSchema): @@ -78,7 +78,7 @@ class StubV1_2AgentMessage(AgentMessage): class Meta: handler_class = "StubV1_2AgentMessageHandler" schema_class = "StubV1_2AgentMessageSchema" - message_type = "proto-name/1.2/message-type" + message_type = "doc/proto-name/1.2/message-type" class StubV1_2AgentMessageSchema(AgentMessageSchema): @@ -113,15 +113,7 @@ async def test_dispatch(self): StubAgentMessageHandler, "handle", autospec=True ) as handler_mock, mock.patch.object( test_module, "BaseConnectionManager", autospec=True - ) as conn_mgr_mock, mock.patch.object( - test_module, - "get_version_from_message_type", - mock.MagicMock(return_value="1.1"), - ), mock.patch.object( - test_module, - "validate_get_response_version", - mock.CoroutineMock(return_value=("1.1", None)), - ): + ) as conn_mgr_mock: conn_mgr_mock.return_value = mock.MagicMock( find_inbound_connection=mock.CoroutineMock( return_value=mock.MagicMock(connection_id="dummy") @@ -162,15 +154,7 @@ async def test_dispatch_versioned_message(self): with mock.patch.object( StubAgentMessageHandler, "handle", autospec=True - ) as handler_mock, mock.patch.object( - test_module, - "get_version_from_message_type", - mock.MagicMock(return_value="1.1"), - ), mock.patch.object( - test_module, - "validate_get_response_version", - mock.CoroutineMock(return_value=("1.1", None)), - ): + ) as handler_mock: await dispatcher.queue_message( dispatcher.profile, make_inbound(message), rcv.send ) @@ -200,7 +184,7 @@ async def test_dispatch_versioned_message_no_message_class(self): dispatcher = test_module.Dispatcher(profile) await dispatcher.setup() rcv = Receiver() - message = {"@type": "proto-name/1.1/no-such-message-type"} + message = {"@type": "doc/proto-name/1.1/no-such-message-type"} with mock.patch.object( StubAgentMessageHandler, "handle", autospec=True @@ -234,7 +218,7 @@ async def test_dispatch_versioned_message_message_class_deserialize_x(self): dispatcher = test_module.Dispatcher(profile) await dispatcher.setup() rcv = Receiver() - message = {"@type": "proto-name/1.1/no-such-message-type"} + message = {"@type": "doc/proto-name/1.1/no-such-message-type"} with mock.patch.object( StubAgentMessageHandler, "handle", autospec=True @@ -281,15 +265,7 @@ async def test_dispatch_versioned_message_handle_greater_succeeds(self): with mock.patch.object( StubAgentMessageHandler, "handle", autospec=True - ) as handler_mock, mock.patch.object( - test_module, - "get_version_from_message_type", - mock.MagicMock(return_value="1.1"), - ), mock.patch.object( - test_module, - "validate_get_response_version", - mock.CoroutineMock(return_value=("1.1", None)), - ): + ) as handler_mock: await dispatcher.queue_message( dispatcher.profile, make_inbound(message), rcv.send ) @@ -341,22 +317,17 @@ async def test_bad_message_dispatch_parse_x(self): await dispatcher.setup() rcv = Receiver() bad_messages = ["not even a dict", {"bad": "message"}] - with mock.patch.object( - test_module, "get_version_from_message_type", mock.MagicMock() - ), mock.patch.object( - test_module, "validate_get_response_version", mock.CoroutineMock() - ): - for bad in bad_messages: - await dispatcher.queue_message( - dispatcher.profile, make_inbound(bad), rcv.send - ) - await dispatcher.task_queue - assert rcv.messages and isinstance(rcv.messages[0][1], OutboundMessage) - payload = json.loads(rcv.messages[0][1].payload) - assert payload["@type"] == DIDCommPrefix.qualify_current( - ProblemReport.Meta.message_type - ) - rcv.messages.clear() + for bad in bad_messages: + await dispatcher.queue_message( + dispatcher.profile, make_inbound(bad), rcv.send + ) + await dispatcher.task_queue + assert rcv.messages and isinstance(rcv.messages[0][1], OutboundMessage) + payload = json.loads(rcv.messages[0][1].payload) + assert payload["@type"] == DIDCommPrefix.qualify_current( + ProblemReport.Meta.message_type + ) + rcv.messages.clear() async def test_bad_message_dispatch_problem_report_x(self): profile = make_profile() @@ -593,91 +564,3 @@ def _smaller_scope(): with pytest.deprecated_call(): with self.assertRaises(RuntimeError): await responder.send_webhook("test", {}) - - # async def test_dispatch_version_with_degraded_features(self): - # profile = make_profile() - # registry = profile.inject(ProtocolRegistry) - # registry.register_message_types( - # { - # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage - # for pfx in DIDCommPrefix - # } - # ) - # dispatcher = test_module.Dispatcher(profile) - # await dispatcher.setup() - # rcv = Receiver() - # message = { - # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) - # } - - # with mock.patch.object( - # test_module, - # "get_version_from_message_type", - # mock.CoroutineMock(return_value="1.1"), - # ), mock.patch.object( - # test_module, - # "validate_get_response_version", - # mock.CoroutineMock(return_value=("1.1", "fields-ignored-due-to-version-mismatch")), - # ): - # await dispatcher.queue_message( - # dispatcher.profile, make_inbound(message), rcv.send - # ) - - # async def test_dispatch_fields_ignored_due_to_version_mismatch(self): - # profile = make_profile() - # registry = profile.inject(ProtocolRegistry) - # registry.register_message_types( - # { - # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage - # for pfx in DIDCommPrefix - # } - # ) - # dispatcher = test_module.Dispatcher(profile) - # await dispatcher.setup() - # rcv = Receiver() - # message = { - # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) - # } - - # with mock.patch.object( - # test_module, - # "get_version_from_message_type", - # mock.CoroutineMock(return_value="1.1"), - # ), mock.patch.object( - # test_module, - # "validate_get_response_version", - # mock.CoroutineMock(return_value=("1.1", "version-with-degraded-features")), - # ): - # await dispatcher.queue_message( - # dispatcher.profile, make_inbound(message), rcv.send - # ) - - # async def test_dispatch_version_not_supported(self): - # profile = make_profile() - # registry = profile.inject(ProtocolRegistry) - # registry.register_message_types( - # { - # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage - # for pfx in DIDCommPrefix - # } - # ) - # dispatcher = test_module.Dispatcher(profile) - # await dispatcher.setup() - # rcv = Receiver() - # message = { - # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) - # } - - # with mock.patch.object( - # test_module, - # "get_version_from_message_type", - # mock.CoroutineMock(return_value="1.1"), - # ), mock.patch.object( - # test_module, - # "validate_get_response_version", - # mock.CoroutineMock(return_value=("1.1", "version-not-supported")), - # ): - # with self.assertRaises(test_module.MessageParseError): - # await dispatcher.queue_message( - # dispatcher.profile, make_inbound(message), rcv.send - # ) diff --git a/aries_cloudagent/core/tests/test_protocol_registry.py b/aries_cloudagent/core/tests/test_protocol_registry.py index d5383bba1d..72c78a2c6f 100644 --- a/aries_cloudagent/core/tests/test_protocol_registry.py +++ b/aries_cloudagent/core/tests/test_protocol_registry.py @@ -2,7 +2,7 @@ from unittest import IsolatedAsyncioTestCase from ...config.injection_context import InjectionContext -from ...utils.classloader import ClassLoader +from ...utils.classloader import ClassLoader, DeferLoad from ..protocol_registry import ProtocolRegistry @@ -10,11 +10,11 @@ class TestProtocolRegistry(IsolatedAsyncioTestCase): no_type_message = {"a": "b"} unknown_type_message = {"@type": 1} - test_message_type = "PROTOCOL/MESSAGE" - test_protocol = "PROTOCOL" - test_protocol_queries = ["*", "PROTOCOL", "PROTO*"] + test_message_type = "doc/protocol/1.0/message" + test_protocol = "doc/protocol/1.0" + test_protocol_queries = ["*", "doc/protocol/1.0", "doc/proto*"] test_protocol_queries_fail = ["", "nomatch", "nomatch*"] - test_message_handler = "fake_handler" + test_message_cls = "fake_msg_cls" test_controller = "fake_controller" def setUp(self): @@ -22,7 +22,7 @@ def setUp(self): def test_protocols(self): self.registry.register_message_types( - {self.test_message_type: self.test_message_handler} + {self.test_message_type: self.test_message_cls} ) self.registry.register_controllers( {self.test_message_type: self.test_controller} @@ -30,13 +30,10 @@ def test_protocols(self): assert list(self.registry.message_types) == [self.test_message_type] assert list(self.registry.protocols) == [self.test_protocol] - assert self.registry.controllers == { - self.test_message_type: self.test_controller - } def test_message_type_query(self): self.registry.register_message_types( - {self.test_message_type: self.test_message_handler} + {self.test_message_type: self.test_message_cls} ) for q in self.test_protocol_queries: matches = self.registry.protocols_matching_query(q) @@ -45,165 +42,127 @@ def test_message_type_query(self): matches = self.registry.protocols_matching_query(q) assert matches == () - def test_create_msg_types_for_minor_version(self): + def test_registration_with_minor_version(self): MSG_PATH = "aries_cloudagent.protocols.introduction.v0_1.messages" - test_typesets = ( - { - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-invitation": f"{MSG_PATH}.invitation.Invitation", - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", - "https://didcom.org/introduction-service/1.0/fake-forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", - "https://didcom.org/introduction-service/1.0/fake-invitation": f"{MSG_PATH}.invitation.Invitation", - "https://didcom.org/introduction-service/1.0/fake-invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", - }, - ) + test_typesets = { + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-invitation": f"{MSG_PATH}.invitation.Invitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", + "https://didcom.org/introduction-service/1.0/fake-forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", + "https://didcom.org/introduction-service/1.0/fake-invitation": f"{MSG_PATH}.invitation.Invitation", + "https://didcom.org/introduction-service/1.0/fake-invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", + } test_version_def = { "current_minor_version": 0, "major_version": 1, "minimum_minor_version": 0, "path": "v0_1", } - updated_typesets = self.registry.create_msg_types_for_minor_version( - test_typesets, test_version_def - ) - updated_typeset = updated_typesets[0] + self.registry.register_message_types(test_typesets, test_version_def) assert ( "https://didcom.org/introduction-service/1.0/fake-forward-invitation" - in updated_typeset + in self.registry.message_types ) assert ( "https://didcom.org/introduction-service/1.0/fake-invitation" - in updated_typeset + in self.registry.message_types ) assert ( "https://didcom.org/introduction-service/1.0/fake-invitation-request" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-forward-invitation" - in updated_typeset + in self.registry.message_types ) - def test_introduction_create_msg_types_for_minor_version(self): - MSG_PATH = "aries_cloudagent.protocols.introduction.v0_1.messages" - test_typesets = ( - { - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation": f"{MSG_PATH}.invitation.Invitation", - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/forward-invitation": f"{MSG_PATH}.invitation_messages.forward_invitation.ForwardInvitation", - "https://didcom.org/introduction-service/0.1/invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", - "https://didcom.org/introduction-service/0.1/invitation": f"{MSG_PATH}.invitation.Invitation", - "https://didcom.org/introduction-service/0.1/forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", - }, - ) + def test_register_msg_types_for_multiple_minor_versions(self): + MSG_PATH = "aries_cloudagent.protocols.out_of_band.v1_0.messages" + test_typesets = { + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/invitation": f"{MSG_PATH}.invitation.Invitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse": f"{MSG_PATH}.reuse.HandshakeReuse", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse-accepted": f"{MSG_PATH}.reuse_accept.HandshakeReuseAccept", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/problem_report": f"{MSG_PATH}.problem_report.OOBProblemReport", + "https://didcom.org/out-of-band/1.1/invitation": f"{MSG_PATH}.invitation.Invitation", + "https://didcom.org/out-of-band/1.1/handshake-reuse": f"{MSG_PATH}.reuse.HandshakeReuse", + "https://didcom.org/out-of-band/1.1/handshake-reuse-accepted": f"{MSG_PATH}.reuse_accept.HandshakeReuseAccept", + "https://didcom.org/out-of-band/1.1/problem_report": f"{MSG_PATH}.problem_report.OOBProblemReport", + } test_version_def = { "current_minor_version": 1, - "major_version": 0, - "minimum_minor_version": 1, + "major_version": 1, + "minimum_minor_version": 0, "path": "v0_1", } - updated_typesets = self.registry.create_msg_types_for_minor_version( - test_typesets, test_version_def - ) - updated_typeset = updated_typesets[0] + self.registry.register_message_types(test_typesets, test_version_def) assert ( - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation-request" - in updated_typeset + "https://didcom.org/out-of-band/1.0/invitation" + in self.registry.message_types ) assert ( - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation" - in updated_typeset + "https://didcom.org/out-of-band/1.0/handshake-reuse" + in self.registry.message_types ) assert ( - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/forward-invitation" - in updated_typeset + "https://didcom.org/out-of-band/1.0/handshake-reuse-accepted" + in self.registry.message_types ) assert ( - "https://didcom.org/introduction-service/0.1/invitation-request" - in updated_typeset + "https://didcom.org/out-of-band/1.0/problem_report" + in self.registry.message_types ) assert ( - "https://didcom.org/introduction-service/0.1/invitation" in updated_typeset + "https://didcom.org/out-of-band/1.1/invitation" + in self.registry.message_types ) assert ( - "https://didcom.org/introduction-service/0.1/forward-invitation" - in updated_typeset - ) - - def test_oob_create_msg_types_for_minor_version(self): - MSG_PATH = "aries_cloudagent.protocols.out_of_band.v1_0.messages" - test_typesets = ( - { - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/invitation": f"{MSG_PATH}.invitation.Invitation", - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse": f"{MSG_PATH}.reuse.HandshakeReuse", - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse-accepted": f"{MSG_PATH}.reuse_accept.HandshakeReuseAccept", - "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/problem_report": f"{MSG_PATH}.problem_report.OOBProblemReport", - "https://didcom.org/out-of-band/1.1/invitation": f"{MSG_PATH}.invitation.Invitation", - "https://didcom.org/out-of-band/1.1/handshake-reuse": f"{MSG_PATH}.reuse.HandshakeReuse", - "https://didcom.org/out-of-band/1.1/handshake-reuse-accepted": f"{MSG_PATH}.reuse_accept.HandshakeReuseAccept", - "https://didcom.org/out-of-band/1.1/problem_report": f"{MSG_PATH}.problem_report.OOBProblemReport", - }, - ) - test_version_def = { - "current_minor_version": 1, - "major_version": 1, - "minimum_minor_version": 0, - "path": "v0_1", - } - updated_typesets = self.registry.create_msg_types_for_minor_version( - test_typesets, test_version_def + "https://didcom.org/out-of-band/1.1/handshake-reuse" + in self.registry.message_types ) - updated_typeset = updated_typesets[0] - assert "https://didcom.org/out-of-band/1.0/invitation" in updated_typeset - assert "https://didcom.org/out-of-band/1.0/handshake-reuse" in updated_typeset assert ( - "https://didcom.org/out-of-band/1.0/handshake-reuse-accepted" - in updated_typeset + "https://didcom.org/out-of-band/1.1/handshake-reuse-accepted" + in self.registry.message_types ) - assert "https://didcom.org/out-of-band/1.0/problem_report" in updated_typeset - assert "https://didcom.org/out-of-band/1.1/invitation" in updated_typeset - assert "https://didcom.org/out-of-band/1.1/handshake-reuse" in updated_typeset assert ( - "https://didcom.org/out-of-band/1.1/handshake-reuse-accepted" - in updated_typeset + "https://didcom.org/out-of-band/1.1/problem_report" + in self.registry.message_types ) - assert "https://didcom.org/out-of-band/1.1/problem_report" in updated_typeset assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/invitation" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/handshake-reuse" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/handshake-reuse-accepted" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/problem_report" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/invitation" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse-accepted" - in updated_typeset + in self.registry.message_types ) assert ( "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/problem_report" - in updated_typeset + in self.registry.message_types ) async def test_disclosed(self): self.registry.register_message_types( - {self.test_message_type: self.test_message_handler} + {self.test_message_type: self.test_message_cls} ) mocked = mock.MagicMock() mocked.return_value.check_access = mock.CoroutineMock() @@ -222,7 +181,7 @@ async def test_disclosed(self): async def test_disclosed_str(self): self.registry.register_message_types( - {self.test_message_type: self.test_message_handler} + {self.test_message_type: self.test_message_cls} ) self.registry.register_controllers({self.test_protocol: "mock-class-name"}) protocols = [self.test_protocol] @@ -244,24 +203,20 @@ async def check_access(self, context): def test_resolve_message_class_str(self): self.registry.register_message_types( - {self.test_message_type: self.test_message_handler} + {self.test_message_type: self.test_message_cls} ) - mock_class = mock.MagicMock() - with mock.patch.object( - ClassLoader, "load_class", mock.MagicMock() - ) as load_class: - load_class.return_value = mock_class - result = self.registry.resolve_message_class(self.test_message_type) - assert result == mock_class + result = self.registry.resolve_message_class(self.test_message_type) + assert isinstance(result, DeferLoad) + assert result._cls_path == self.test_message_cls def test_resolve_message_class_no_major_version_support(self): - result = self.registry.resolve_message_class("proto/1.2/hello") + result = self.registry.resolve_message_class("doc/proto/1.2/hello") assert result is None def test_resolve_message_load_class_str(self): - message_type_a = "proto/1.2/aaa" + message_type_a = "doc/proto/1.2/aaa" self.registry.register_message_types( - {message_type_a: self.test_message_handler}, + {message_type_a: self.test_message_cls}, version_definition={ "major_version": 1, "minimum_minor_version": 0, @@ -269,18 +224,14 @@ def test_resolve_message_load_class_str(self): "path": "v1_2", }, ) - mock_class = mock.MagicMock() - with mock.patch.object( - ClassLoader, "load_class", mock.MagicMock() - ) as load_class: - load_class.side_effect = [mock_class, mock_class] - result = self.registry.resolve_message_class("proto/1.1/aaa") - assert result == mock_class + result = self.registry.resolve_message_class("doc/proto/1.1/aaa") + assert isinstance(result, DeferLoad) + assert result._cls_path == self.test_message_cls def test_resolve_message_load_class_none(self): - message_type_a = "proto/1.2/aaa" + message_type_a = "doc/proto/1.2/aaa" self.registry.register_message_types( - {message_type_a: self.test_message_handler}, + {message_type_a: self.test_message_cls}, version_definition={ "major_version": 1, "minimum_minor_version": 0, @@ -288,13 +239,8 @@ def test_resolve_message_load_class_none(self): "path": "v1_2", }, ) - mock_class = mock.MagicMock() - with mock.patch.object( - ClassLoader, "load_class", mock.MagicMock() - ) as load_class: - load_class.side_effect = [mock_class, mock_class] - result = self.registry.resolve_message_class("proto/1.2/bbb") - assert result is None + result = self.registry.resolve_message_class("doc/proto/1.2/bbb") + assert result is None def test_repr(self): assert isinstance(repr(self.registry), str) diff --git a/aries_cloudagent/core/tests/test_util.py b/aries_cloudagent/core/tests/test_util.py deleted file mode 100644 index 5c28b6baa8..0000000000 --- a/aries_cloudagent/core/tests/test_util.py +++ /dev/null @@ -1,82 +0,0 @@ -from unittest import IsolatedAsyncioTestCase - -from ...cache.base import BaseCache -from ...cache.in_memory import InMemoryCache -from ...core.in_memory import InMemoryProfile -from ...core.profile import Profile -from ...protocols.didcomm_prefix import DIDCommPrefix -from ...protocols.introduction.v0_1.messages.invitation import Invitation -from ...protocols.out_of_band.v1_0.messages.reuse import HandshakeReuse - -from .. import util as test_module - - -def make_profile() -> Profile: - profile = InMemoryProfile.test_profile() - profile.context.injector.bind_instance(BaseCache, InMemoryCache()) - return profile - - -class TestUtils(IsolatedAsyncioTestCase): - async def test_validate_get_response_version(self): - profile = make_profile() - (resp_version, warning) = await test_module.validate_get_response_version( - profile, "1.1", HandshakeReuse - ) - assert resp_version == "1.1" - assert not warning - - # cached - (resp_version, warning) = await test_module.validate_get_response_version( - profile, "1.1", HandshakeReuse - ) - assert resp_version == "1.1" - assert not warning - - (resp_version, warning) = await test_module.validate_get_response_version( - profile, "1.0", HandshakeReuse - ) - assert resp_version == "1.0" - assert warning == test_module.WARNING_DEGRADED_FEATURES - - (resp_version, warning) = await test_module.validate_get_response_version( - profile, "1.2", HandshakeReuse - ) - assert resp_version == "1.1" - assert warning == test_module.WARNING_VERSION_MISMATCH - - with self.assertRaises(test_module.ProtocolMinorVersionNotSupported): - (resp_version, warning) = await test_module.validate_get_response_version( - profile, "0.0", Invitation - ) - - with self.assertRaises(Exception): - (resp_version, warning) = await test_module.validate_get_response_version( - profile, "1.0", Invitation - ) - - def test_get_version_from_message_type(self): - assert ( - test_module.get_version_from_message_type( - DIDCommPrefix.qualify_current("out-of-band/1.1/handshake-reuse") - ) - == "1.1" - ) - - def test_get_version_from_message(self): - assert test_module.get_version_from_message(HandshakeReuse()) == "1.1" - - async def test_get_proto_default_version_from_msg_class(self): - profile = make_profile() - assert ( - await test_module.get_proto_default_version_from_msg_class( - profile, HandshakeReuse - ) - ) == "1.1" - - def test_get_proto_default_version(self): - assert ( - test_module.get_proto_default_version( - "aries_cloudagent.protocols.out_of_band.definition" - ) - ) == "1.1" diff --git a/aries_cloudagent/core/util.py b/aries_cloudagent/core/util.py index 2b35da9987..865da41f3b 100644 --- a/aries_cloudagent/core/util.py +++ b/aries_cloudagent/core/util.py @@ -1,18 +1,7 @@ """Core utilities and constants.""" -import inspect -import os import re -from typing import Optional, Tuple - -from ..cache.base import BaseCache -from ..core.profile import Profile -from ..messaging.agent_message import AgentMessage -from ..utils.classloader import ClassLoader - -from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError - CORE_EVENT_PREFIX = "acapy::core::" STARTUP_EVENT_TOPIC = CORE_EVENT_PREFIX + "startup" STARTUP_EVENT_PATTERN = re.compile(f"^{STARTUP_EVENT_TOPIC}?$") @@ -21,159 +10,3 @@ WARNING_DEGRADED_FEATURES = "version-with-degraded-features" WARNING_VERSION_MISMATCH = "fields-ignored-due-to-version-mismatch" WARNING_VERSION_NOT_SUPPORTED = "version-not-supported" - - -async def validate_get_response_version( - profile: Profile, rec_version: str, msg_class: type -) -> Tuple[str, Optional[str]]: - """Return a tuple with version to respond with and warnings. - - Process received version and protocol version definition, - returns the tuple. - - Args: - profile: Profile - rec_version: received version from message - msg_class: type - - Returns: - Tuple with response version and any warnings - - """ - resp_version = rec_version - warning = None - version_string_tokens = rec_version.split(".") - rec_major_version = int(version_string_tokens[0]) - rec_minor_version = int(version_string_tokens[1]) - version_definition = await get_version_def_from_msg_class( - profile, msg_class, rec_major_version - ) - proto_major_version = int(version_definition["major_version"]) - proto_curr_minor_version = int(version_definition["current_minor_version"]) - proto_min_minor_version = int(version_definition["minimum_minor_version"]) - if rec_minor_version < proto_min_minor_version: - warning = WARNING_VERSION_NOT_SUPPORTED - elif ( - rec_minor_version >= proto_min_minor_version - and rec_minor_version < proto_curr_minor_version - ): - warning = WARNING_DEGRADED_FEATURES - elif rec_minor_version > proto_curr_minor_version: - warning = WARNING_VERSION_MISMATCH - if proto_major_version == rec_major_version: - if ( - proto_min_minor_version <= rec_minor_version - and proto_curr_minor_version >= rec_minor_version - ): - resp_version = f"{str(proto_major_version)}.{str(rec_minor_version)}" - elif rec_minor_version > proto_curr_minor_version: - resp_version = f"{str(proto_major_version)}.{str(proto_curr_minor_version)}" - elif rec_minor_version < proto_min_minor_version: - raise ProtocolMinorVersionNotSupported( - "Minimum supported minor version is " - + f"{proto_min_minor_version}." - + f" Received {rec_minor_version}." - ) - else: - raise ProtocolMinorVersionNotSupported( - f"Supported major version {proto_major_version}" - " is not same as received major version" - f" {rec_major_version}." - ) - return (resp_version, warning) - - -def get_version_from_message_type(msg_type: str) -> str: - """Return version from provided message_type.""" - return (re.search(r"(\d+\.)?(\*|\d+)", msg_type)).group() - - -def get_version_from_message(msg: AgentMessage) -> str: - """Return version from provided AgentMessage.""" - msg_type = msg._type - return get_version_from_message_type(msg_type) - - -async def get_proto_default_version_from_msg_class( - profile: Profile, msg_class: type, major_version: int = 1 -) -> str: - """Return default protocol version from version_definition.""" - version_definition = await get_version_def_from_msg_class( - profile, msg_class, major_version - ) - return _get_default_version_from_version_def(version_definition) - - -def get_proto_default_version(def_path: str, major_version: int = 1) -> str: - """Return default protocol version from version_definition.""" - version_definition = _get_version_def_from_path(def_path, major_version) - return _get_default_version_from_version_def(version_definition) - - -def _resolve_definition(search_path: str, msg_class: type) -> str: - try: - path = os.path.normpath(inspect.getfile(msg_class)) - path = search_path + path.rsplit(search_path, 1)[1] - version = (re.search(r"v(\d+\_)?(\*|\d+)", path)).group() - path = path.split(version, 1)[0] - definition_path = (path.replace("/", ".")) + "definition" - if ClassLoader.load_module(definition_path): - return definition_path - except Exception: - # we expect some exceptions resolving paths - pass - - -def _get_path_from_msg_class(msg_class: type) -> str: - search_paths = ["aries_cloudagent", msg_class.__module__.split(".", 1)[0]] - if os.getenv("ACAPY_HOME"): - search_paths.insert(os.getenv("ACAPY_HOME"), 0) - - definition_path = None - searches = 0 - while not definition_path and searches < len(search_paths): - definition_path = _resolve_definition(search_paths[searches], msg_class) - searches = searches + 1 - # we could throw an exception here, - return definition_path - - -def _get_version_def_from_path(definition_path: str, major_version: int = 1): - version_definition = None - definition = ClassLoader.load_module(definition_path) - for protocol_version in definition.versions: - if major_version == protocol_version["major_version"]: - version_definition = protocol_version - break - return version_definition - - -def _get_default_version_from_version_def(version_definition) -> str: - default_major_version = version_definition["major_version"] - default_minor_version = version_definition["current_minor_version"] - return f"{default_major_version}.{default_minor_version}" - - -async def get_version_def_from_msg_class( - profile: Profile, msg_class: type, major_version: int = 1 -): - """Return version_definition of a protocol from msg_class.""" - cache = profile.inject_or(BaseCache) - version_definition = None - if cache: - version_definition = await cache.get( - f"version_definition::{str(msg_class).lower()}" - ) - if version_definition: - return version_definition - definition_path = _get_path_from_msg_class(msg_class) - version_definition = _get_version_def_from_path(definition_path, major_version) - if not version_definition: - raise ProtocolDefinitionValidationError( - f"Unable to load protocol version_definition for {str(msg_class)}" - ) - if cache: - await cache.set( - f"version_definition::{str(msg_class).lower()}", version_definition - ) - return version_definition diff --git a/aries_cloudagent/messaging/agent_message.py b/aries_cloudagent/messaging/agent_message.py index 08e6ec94e9..ccb5895183 100644 --- a/aries_cloudagent/messaging/agent_message.py +++ b/aries_cloudagent/messaging/agent_message.py @@ -1,9 +1,8 @@ """Agent message base class and schema.""" -import uuid from collections import OrderedDict -from re import sub from typing import Mapping, Optional, Text, Union +import uuid from marshmallow import ( EXCLUDE, @@ -21,7 +20,7 @@ from .decorators.base import BaseDecoratorSet from .decorators.default import DecoratorSet from .decorators.service_decorator import ServiceDecorator -from .decorators.signature_decorator import SignatureDecorator # TODO deprecated +from .decorators.signature_decorator import SignatureDecorator from .decorators.thread_decorator import ThreadDecorator from .decorators.trace_decorator import ( TRACE_LOG_TARGET, @@ -29,6 +28,7 @@ TraceDecorator, TraceReport, ) +from .message_type import MessageTypeStr from .models.base import ( BaseModel, BaseModelError, @@ -86,17 +86,13 @@ def __init__( self.__class__.__name__ ) ) - if _type: - self._message_type = _type - elif _version: - self._message_type = self.get_updated_msg_type(_version) - else: - self._message_type = self.Meta.message_type - # Not required for now - # if not self.Meta.handler_class: - # raise TypeError( - # "Can't instantiate abstract class {} with no handler_class".format( - # self.__class__.__name__)) + + self._message_type = MessageTypeStr( + DIDCommPrefix.qualify_current(_type or self.Meta.message_type) + ) + + if _version: + self.assign_version(_version) @classmethod def _get_handler_class(cls): @@ -119,19 +115,14 @@ def Handler(self) -> type: return self._get_handler_class() @property - def _type(self) -> str: + def _type(self) -> MessageTypeStr: """Accessor for the message type identifier. Returns: Current DIDComm prefix, slash, message type defined on `Meta.message_type` """ - return DIDCommPrefix.qualify_current(self._message_type) - - @_type.setter - def _type(self, msg_type: str): - """Set the message type identifier.""" - self._message_type = msg_type + return self._message_type @property def _id(self) -> str: @@ -153,15 +144,35 @@ def _decorators(self) -> BaseDecoratorSet: """Fetch the message's decorator set.""" return self._message_decorators + @property + def _version(self) -> str: + """Accessor for the message version.""" + return str(self._type.version) + + def assign_version_from(self, msg: "AgentMessage"): + """Copy version information from a previous message. + + Args: + msg: The received message containing version information to copy + + """ + if msg: + self.assign_version(msg._version) + + def assign_version(self, version: str): + """Assign a specific version. + + Args: + version: The version to assign + + """ + self._message_type = self._message_type.with_version(version) + @_decorators.setter def _decorators(self, value: BaseDecoratorSet): """Fetch the message's decorator set.""" self._message_decorators = value - def get_updated_msg_type(self, version: str) -> str: - """Update version to Meta.message_type.""" - return sub(r"(\d+\.)?(\*|\d+)", version, self.Meta.message_type) - def get_signature(self, field_name: str) -> SignatureDecorator: """Get the signature for a named field. @@ -294,7 +305,7 @@ def _thread(self) -> ThreadDecorator: return self._decorators.get("thread") @_thread.setter - def _thread(self, val: Union[ThreadDecorator, dict]): + def _thread(self, val: Union[ThreadDecorator, dict, None]): """Setter for the message's thread decorator. Args: @@ -324,7 +335,7 @@ def assign_thread_from(self, msg: "AgentMessage"): pthid = thread and thread.pthid self.assign_thread_id(thid, pthid) - def assign_thread_id(self, thid: str, pthid: str = None): + def assign_thread_id(self, thid: str, pthid: Optional[str] = None): """Assign a specific thread ID. Args: diff --git a/aries_cloudagent/messaging/decorators/attach_decorator.py b/aries_cloudagent/messaging/decorators/attach_decorator.py index cdfa31a696..1ed1ec123f 100644 --- a/aries_cloudagent/messaging/decorators/attach_decorator.py +++ b/aries_cloudagent/messaging/decorators/attach_decorator.py @@ -364,7 +364,6 @@ def build_protected(verkey: str): json.dumps( { "alg": "EdDSA", - "kid": did_key(verkey), "jwk": { "kty": "OKP", "crv": "Ed25519", diff --git a/aries_cloudagent/messaging/decorators/thread_decorator.py b/aries_cloudagent/messaging/decorators/thread_decorator.py index cbae0684cc..32fb93170e 100644 --- a/aries_cloudagent/messaging/decorators/thread_decorator.py +++ b/aries_cloudagent/messaging/decorators/thread_decorator.py @@ -4,7 +4,7 @@ context from previous messages. """ -from typing import Mapping +from typing import Mapping, Optional from marshmallow import EXCLUDE, fields @@ -23,10 +23,10 @@ class Meta: def __init__( self, *, - thid: str = None, - pthid: str = None, - sender_order: int = None, - received_orders: Mapping = None, + thid: Optional[str] = None, + pthid: Optional[str] = None, + sender_order: Optional[int] = None, + received_orders: Optional[Mapping] = None, ): """Initialize a ThreadDecorator instance. diff --git a/aries_cloudagent/messaging/message_type.py b/aries_cloudagent/messaging/message_type.py new file mode 100644 index 0000000000..9e387eae6f --- /dev/null +++ b/aries_cloudagent/messaging/message_type.py @@ -0,0 +1,232 @@ +"""Utilities for working with Message Types and Versions.""" + +from dataclasses import dataclass +from functools import lru_cache +import re +from typing import ClassVar, Pattern, Tuple, Union + + +@dataclass +class MessageVersion: + """Message type version.""" + + PATTERN: ClassVar[Pattern] = re.compile(r"^(0|[1-9]\d*)\.(0|[1-9]\d*)$") + + major: int + minor: int + + @classmethod + @lru_cache + def from_str(cls, value: str): + """Parse a version string.""" + if match := cls.PATTERN.match(value): + return cls( + int(match.group(1)), + int(match.group(2)), + ) + + raise ValueError(f"Invalid version: {value}") + + def __gt__(self, other: "MessageVersion") -> bool: + """Test whether this version is greater than another.""" + if self.major != other.major: + return self.major > other.major + return self.minor > other.minor + + def __eq__(self, other: object) -> bool: + """Equality comparison.""" + if not isinstance(other, MessageVersion): + return False + + return self.major == other.major and self.minor == other.minor + + def __lt__(self, other: "MessageVersion") -> bool: + """Test whether this version is less than another.""" + if self.major != other.major: + return self.major < other.major + return self.minor < other.minor + + def __str__(self) -> str: + """Return the version as a string.""" + return f"{self.major}.{self.minor}" + + def __hash__(self) -> int: + """Return a hash of the version.""" + return hash((self.major, self.minor)) + + def compatible(self, other: "MessageVersion") -> bool: + """Test whether this version is compatible with another.""" + if self == other: + return True + return self.major == other.major and self.minor <= other.minor + + +@dataclass +class ProtocolIdentifier: + """Protocol identifier.""" + + PATTERN: ClassVar[Pattern] = re.compile(r"^(.*?)/([a-z0-9._-]+)/(\d[^/]*)$") + FROM_MESSAGE_TYPE_PATTERN: ClassVar[Pattern] = re.compile( + r"^(.*?)/([a-z0-9._-]+)/(\d[^/]*).*$" + ) + + doc_uri: str + protocol: str + version: MessageVersion + + @classmethod + @lru_cache + def from_str(cls, value: str) -> "ProtocolIdentifier": + """Parse a protocol identifier string.""" + if match := cls.PATTERN.match(value): + return cls( + doc_uri=match.group(1), + protocol=match.group(2), + version=MessageVersion.from_str(match.group(3)), + ) + raise ValueError(f"Invalid protocol identifier: {value}") + + @classmethod + @lru_cache + def from_message_type( + cls, message_type: Union[str, "MessageType"] + ) -> "ProtocolIdentifier": + """Create a protocol identifier from a message type.""" + if isinstance(message_type, str): + if match := cls.FROM_MESSAGE_TYPE_PATTERN.match(message_type): + return cls( + doc_uri=match.group(1), + protocol=match.group(2), + version=MessageVersion.from_str(match.group(3)), + ) + + raise ValueError(f"Invalid protocol identifier: {message_type}") + elif isinstance(message_type, MessageType): + return cls( + message_type.doc_uri, message_type.protocol, message_type.version + ) + else: + raise TypeError(f"Invalid message type: {message_type}") + + def __str__(self) -> str: + """Return the protocol identifier as a string.""" + return f"{self.doc_uri}/{self.protocol}/{self.version}" + + @property + def stem(self) -> str: + """Return the protocol stem, including doc_uri, protocol, and major version.""" + return f"{self.doc_uri}/{self.protocol}/{self.version.major}" + + def with_version( + self, version: Union[str, MessageVersion, Tuple[int, int]] + ) -> "ProtocolIdentifier": + """Return a new protocol identifier with the specified version.""" + if isinstance(version, str): + version = MessageVersion.from_str(version) + + if isinstance(version, tuple): + version = MessageVersion(*version) + + return ProtocolIdentifier( + doc_uri=self.doc_uri, + protocol=self.protocol, + version=version, + ) + + +@dataclass +class MessageType: + """Message type.""" + + PATTERN: ClassVar[Pattern] = re.compile( + r"^(.*?)/([a-z0-9._-]+)/(\d[^/]*)/([a-z0-9._-]+)$" + ) + + doc_uri: str + protocol: str + version: MessageVersion + name: str + + @classmethod + @lru_cache + def from_str(cls, value: str): + """Parse a message type string.""" + if match := cls.PATTERN.match(value): + return cls( + doc_uri=match.group(1), + protocol=match.group(2), + version=MessageVersion.from_str(match.group(3)), + name=match.group(4), + ) + + raise ValueError(f"Invalid message type: {value}") + + def __str__(self) -> str: + """Return the message type as a string.""" + return f"{self.doc_uri}/{self.protocol}/{self.version}/{self.name}" + + def with_version( + self, version: Union[str, MessageVersion, Tuple[int, int]] + ) -> "MessageType": + """Return a new message type with the specified version.""" + if isinstance(version, str): + version = MessageVersion.from_str(version) + + if isinstance(version, tuple): + version = MessageVersion(*version) + + return MessageType( + doc_uri=self.doc_uri, + protocol=self.protocol, + version=version, + name=self.name, + ) + + def __hash__(self) -> int: + """Return a hash of the message type.""" + return hash((self.doc_uri, self.protocol, self.version, self.name)) + + +class MessageTypeStr(str): + """Message type string.""" + + def __init__(self, value: str): + """Initialize the message type string.""" + super().__init__() + self._parsed = MessageType.from_str(value) + + @property + def parsed(self) -> MessageType: + """Return the parsed message type.""" + return self._parsed + + @property + def doc_uri(self) -> str: + """Return the message type document URI.""" + return self._parsed.doc_uri + + @property + def protocol(self) -> str: + """Return the message type protocol.""" + return self._parsed.protocol + + @property + def version(self) -> MessageVersion: + """Return the message type version.""" + return self._parsed.version + + @property + def name(self) -> str: + """Return the message type name.""" + return self._parsed.name + + @property + def protocol_identifier(self) -> ProtocolIdentifier: + """Return the message type protocol identifier.""" + return ProtocolIdentifier.from_message_type(self._parsed) + + def with_version( + self, version: Union[str, MessageVersion, Tuple[int, int]] + ) -> "MessageTypeStr": + """Return a new message type string with the specified version.""" + return MessageTypeStr(str(self._parsed.with_version(version))) diff --git a/aries_cloudagent/messaging/responder.py b/aries_cloudagent/messaging/responder.py index 12f0d3256f..6c41393ee2 100644 --- a/aries_cloudagent/messaging/responder.py +++ b/aries_cloudagent/messaging/responder.py @@ -19,6 +19,9 @@ from .base_message import BaseMessage SKIP_ACTIVE_CONN_CHECK_MSG_TYPES = [ + "didexchange/1.1/request", + "didexchange/1.1/response", + "didexchange/1.1/problem_report", "didexchange/1.0/request", "didexchange/1.0/response", "didexchange/1.0/problem_report", diff --git a/aries_cloudagent/messaging/tests/test_agent_message.py b/aries_cloudagent/messaging/tests/test_agent_message.py index 12105fcbfe..bf3ee6fa85 100644 --- a/aries_cloudagent/messaging/tests/test_agent_message.py +++ b/aries_cloudagent/messaging/tests/test_agent_message.py @@ -20,7 +20,7 @@ class Meta: handler_class = None schema_class = "SignedAgentMessageSchema" - message_type = "signed-agent-message" + message_type = "doc/protocol/1.0/signed-agent-message" def __init__(self, value: str = None, **kwargs): super().__init__(**kwargs) @@ -45,7 +45,7 @@ class Meta: """Meta data""" schema_class = AgentMessageSchema - message_type = "basic-message" + message_type = "doc/protocol/1.0/basic-message" class TestAgentMessage(IsolatedAsyncioTestCase): diff --git a/aries_cloudagent/messaging/util.py b/aries_cloudagent/messaging/util.py index acf1c9d962..7c924f6ef3 100644 --- a/aries_cloudagent/messaging/util.py +++ b/aries_cloudagent/messaging/util.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone from hashlib import sha256 from math import floor -from typing import Any, Union +from typing import Any, Dict, List, Optional, Union LOGGER = logging.getLogger(__name__) @@ -147,3 +147,17 @@ def canon(raw_attr_name: str) -> str: if raw_attr_name: # do not dereference None, and "" is already canonical return raw_attr_name.replace(" ", "").lower() return raw_attr_name + + +def get_proto_default_version( + versions: List[Dict[str, Any]], major_version: int = 1 +) -> Optional[str]: + """Return default protocol version from version definition list.""" + + for version in versions: + if major_version == version["major_version"]: + default_major_version = version["major_version"] + default_minor_version = version["current_minor_version"] + return f"{default_major_version}.{default_minor_version}" + + return None diff --git a/aries_cloudagent/protocols/connections/v1_0/routes.py b/aries_cloudagent/protocols/connections/v1_0/routes.py index 4e6ff949ba..f8460ac82a 100644 --- a/aries_cloudagent/protocols/connections/v1_0/routes.py +++ b/aries_cloudagent/protocols/connections/v1_0/routes.py @@ -284,12 +284,10 @@ class ConnectionsListQueryStringSchema(OpenAPISchema): ) connection_protocol = fields.Str( required=False, - validate=validate.OneOf( - [proto.aries_protocol for proto in ConnRecord.Protocol] - ), + validate=validate.OneOf(ConnRecord.SUPPORTED_PROTOCOLS), metadata={ "description": "Connection protocol used", - "example": ConnRecord.Protocol.RFC_0160.aries_protocol, + "example": "connections/1.0", }, ) invitation_msg_id = fields.Str( diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py index 3328c9adbe..d880f17e59 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py @@ -32,7 +32,7 @@ async def test_connections_list(self): self.request.query = { "invitation_id": "dummy", # exercise tag filter assignment "their_role": ConnRecord.Role.REQUESTER.rfc160, - "connection_protocol": ConnRecord.Protocol.RFC_0160.aries_protocol, + "connection_protocol": "connections/1.0", "invitation_key": "some-invitation-key", "their_public_did": "a_public_did", "invitation_msg_id": "dummy_msg", @@ -99,7 +99,7 @@ async def test_connections_list(self): }, post_filter_positive={ "their_role": list(ConnRecord.Role.REQUESTER.value), - "connection_protocol": ConnRecord.Protocol.RFC_0160.aries_protocol, + "connection_protocol": "connections/1.0", }, alt=True, ) diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py index f217637673..15c3cf80af 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py @@ -63,7 +63,7 @@ async def test_hangup(self): mock_conn_record.delete_record.assert_called_once() mock_send.assert_called_once() assert ( - msg._type == DIDCommPrefix.OLD.value + "/" + test_message_types.HANGUP + msg._type == DIDCommPrefix.NEW.value + "/" + test_message_types.HANGUP ) async def test_receive_hangup(self): @@ -83,7 +83,7 @@ async def test_rotate_my_did(self): msg = await self.manager.rotate_my_did(mock_conn_record, test_to_did) mock_send.assert_called_once() assert ( - msg._type == DIDCommPrefix.OLD.value + "/" + test_message_types.ROTATE + msg._type == DIDCommPrefix.NEW.value + "/" + test_message_types.ROTATE ) async def test_receive_rotate(self): diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py index 7cb9658b86..a4b68f08c4 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py @@ -5,6 +5,7 @@ from .....protocols.didcomm_prefix import DIDCommPrefix from .....storage.error import StorageNotFoundError from .....tests import mock +from ..messages import Hangup, Rotate from .. import message_types as test_message_types from .. import routes as test_module from ..tests import MockConnRecord, test_conn_id @@ -15,20 +16,12 @@ def generate_mock_hangup_message(): - schema = test_module.HangupMessageSchema() - msg = schema.load({}) - - msg._id = "test-message-id" - msg._type = test_message_types.HANGUP + msg = Hangup(_id="test-message-id") return msg def generate_mock_rotate_message(): - schema = test_module.RotateMesageSchema() - msg = schema.load(test_valid_rotate_request) - - msg._id = "test-message-id" - msg._type = test_message_types.ROTATE + msg = Rotate(_id="test-message-id", **test_valid_rotate_request) return msg @@ -73,7 +66,7 @@ async def test_rotate(self, *_): mock_response.assert_called_once_with( { "@id": "test-message-id", - "@type": DIDCommPrefix.OLD.value + "/" + test_message_types.ROTATE, + "@type": DIDCommPrefix.NEW.value + "/" + test_message_types.ROTATE, **test_valid_rotate_request, } ) @@ -101,7 +94,7 @@ async def test_hangup(self, *_): mock_response.assert_called_once_with( { "@id": "test-message-id", - "@type": DIDCommPrefix.OLD.value + "/" + test_message_types.HANGUP, + "@type": DIDCommPrefix.NEW.value + "/" + test_message_types.HANGUP, } ) diff --git a/aries_cloudagent/protocols/didcomm_prefix.py b/aries_cloudagent/protocols/didcomm_prefix.py index dea36b7752..dcd30076e5 100644 --- a/aries_cloudagent/protocols/didcomm_prefix.py +++ b/aries_cloudagent/protocols/didcomm_prefix.py @@ -3,8 +3,6 @@ import re from enum import Enum -from os import environ -from typing import Mapping QUALIFIED = re.compile(r"^[a-zA-Z\-\+]+:.+") @@ -21,17 +19,7 @@ class DIDCommPrefix(Enum): NEW = "https://didcomm.org" OLD = "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec" - @staticmethod - def set(settings: Mapping): - """Set current DIDComm prefix value in environment.""" - - environ["DIDCOMM_PREFIX"] = ( - DIDCommPrefix.NEW.value - if settings.get("emit_new_didcomm_prefix") - else DIDCommPrefix.OLD.value - ) - - def qualify(self, msg_type: str = None) -> str: + def qualify(self, msg_type: str) -> str: """Qualify input message type with prefix and separator.""" return qualify(msg_type, self.value) @@ -43,10 +31,13 @@ def qualify_all(cls, messages: dict) -> dict: return {qualify(k, pfx.value): v for pfx in cls for k, v in messages.items()} @staticmethod - def qualify_current(slug: str = None) -> str: - """Qualify input slug with prefix currently in effect and separator.""" + def qualify_current(slug: str) -> str: + """Qualify input slug with prefix currently in effect and separator. + + This method now will always use the new prefix. + """ - return qualify(slug, environ.get("DIDCOMM_PREFIX", DIDCommPrefix.OLD.value)) + return qualify(slug, DIDCommPrefix.NEW.value) @staticmethod def unqualify(qual: str) -> str: diff --git a/aries_cloudagent/protocols/didexchange/definition.py b/aries_cloudagent/protocols/didexchange/definition.py index 62bddef6f5..13c1f8a8ef 100644 --- a/aries_cloudagent/protocols/didexchange/definition.py +++ b/aries_cloudagent/protocols/didexchange/definition.py @@ -4,7 +4,7 @@ { "major_version": 1, "minimum_minor_version": 0, - "current_minor_version": 0, + "current_minor_version": 1, "path": "v1_0", } ] diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index c65b730ccb..b10f326772 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -2,7 +2,7 @@ import json import logging -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union from did_peer_4 import LONG_PATTERN, long_to_short @@ -24,12 +24,13 @@ from ....wallet.error import WalletError from ....wallet.key_type import ED25519 from ...coordinate_mediation.v1_0.manager import MediationManager +from ...coordinate_mediation.v1_0.models.mediation_record import MediationRecord from ...discovery.v2_0.manager import V20DiscoveryMgr from ...out_of_band.v1_0.messages.invitation import ( InvitationMessage as OOBInvitationMessage, ) from ...out_of_band.v1_0.messages.service import Service as OOBService -from .message_types import ARIES_PROTOCOL as DIDX_PROTO +from .message_types import ARIES_PROTOCOL as DIDEX_1_1, DIDEX_1_0 from .messages.complete import DIDXComplete from .messages.problem_report import DIDXProblemReport, ProblemReportReason from .messages.request import DIDXRequest @@ -40,6 +41,13 @@ class DIDXManagerError(BaseError): """Connection error.""" +class LegacyHandlingFallback(DIDXManagerError): + """Raised when a request cannot be completed using updated semantics. + + Triggers falling back to legacy handling. + """ + + class DIDXManager(BaseConnectionManager): """Class for managing connections under RFC 23 (DID exchange).""" @@ -70,6 +78,7 @@ async def receive_invitation( auto_accept: Optional[bool] = None, alias: Optional[str] = None, mediation_id: Optional[str] = None, + protocol: Optional[str] = None, ) -> ConnRecord: # leave in didexchange as it uses a responder: not out-of-band """Create a new connection record to track a received invitation. @@ -108,6 +117,9 @@ async def receive_invitation( ) else ConnRecord.ACCEPT_MANUAL ) + protocol = protocol or DIDEX_1_0 + if protocol not in ConnRecord.SUPPORTED_PROTOCOLS: + raise DIDXManagerError(f"Unexpected protocol: {protocol}") service_item = invitation.services[0] # Create connection record @@ -124,7 +136,7 @@ async def receive_invitation( accept=accept, alias=alias, their_public_did=their_public_did, - connection_protocol=DIDX_PROTO, + connection_protocol=protocol, ) async with self.profile.session() as session: @@ -170,14 +182,15 @@ async def receive_invitation( async def create_request_implicit( self, their_public_did: str, - my_label: str = None, - my_endpoint: str = None, - mediation_id: str = None, + my_label: Optional[str] = None, + my_endpoint: Optional[str] = None, + mediation_id: Optional[str] = None, use_public_did: bool = False, - alias: str = None, - goal_code: str = None, - goal: str = None, + alias: Optional[str] = None, + goal_code: Optional[str] = None, + goal: Optional[str] = None, auto_accept: bool = False, + protocol: Optional[str] = None, ) -> ConnRecord: """Create and send a request against a public DID only (no explicit invitation). @@ -228,6 +241,7 @@ async def create_request_implicit( and self.profile.settings.get("debug.auto_accept_requests") ) ) + protocol = protocol or DIDEX_1_0 conn_rec = ConnRecord( my_did=( my_public_info.did if my_public_info else None @@ -239,7 +253,7 @@ async def create_request_implicit( invitation_msg_id=None, alias=alias, their_public_did=their_public_did, - connection_protocol=DIDX_PROTO, + connection_protocol=protocol, accept=ConnRecord.ACCEPT_AUTO if auto_accept else ConnRecord.ACCEPT_MANUAL, ) request = await self.create_request( # saves and updates conn_rec @@ -249,7 +263,6 @@ async def create_request_implicit( mediation_id=mediation_id, goal_code=goal_code, goal=goal, - use_public_did=bool(my_public_info), ) conn_rec.request_id = request._id conn_rec.state = ConnRecord.State.REQUEST.rfc160 @@ -269,7 +282,6 @@ async def create_request( mediation_id: Optional[str] = None, goal_code: Optional[str] = None, goal: Optional[str] = None, - use_public_did: bool = False, ) -> DIDXRequest: """Create a new connection request for a previously-received invitation. @@ -281,8 +293,6 @@ async def create_request( service endpoint goal_code: Optional self-attested code for sharing intent of connection goal: Optional self-attested string for sharing intent of connection - use_public_did: Flag whether to use public DID and omit DID Doc - attachment on request Returns: A new `DIDXRequest` message to send to the other agent @@ -295,9 +305,6 @@ async def create_request( or_default=True, ) - my_info = None - - # Create connection request message if my_endpoint: my_endpoints = [my_endpoint] else: @@ -307,52 +314,9 @@ async def create_request( my_endpoints.append(default_endpoint) my_endpoints.extend(self.profile.settings.get("additional_endpoints", [])) - emit_did_peer_4 = self.profile.settings.get("emit_did_peer_4") - emit_did_peer_2 = self.profile.settings.get("emit_did_peer_2") - if emit_did_peer_2 and emit_did_peer_4: - self._logger.warning( - "emit_did_peer_2 and emit_did_peer_4 both set, \ - using did:peer:4" - ) - - if conn_rec.my_did: - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.get_local_did(conn_rec.my_did) - elif emit_did_peer_4: - my_info = await self.create_did_peer_4(my_endpoints, mediation_records) - conn_rec.my_did = my_info.did - elif emit_did_peer_2: - my_info = await self.create_did_peer_2(my_endpoints, mediation_records) - conn_rec.my_did = my_info.did - else: - # Create new DID for connection - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.create_local_did( - method=SOV, - key_type=ED25519, - ) - conn_rec.my_did = my_info.did - - if use_public_did or emit_did_peer_2 or emit_did_peer_4: - # Omit DID Doc attachment if we're using a public DID - did_doc = None - attach = None - did = conn_rec.my_did - if not did.startswith("did:"): - did = f"did:sov:{did}" - else: - did_doc = await self.create_did_document( - my_info, - my_endpoints, - mediation_records=mediation_records, - ) - attach = AttachDecorator.data_base64(did_doc.serialize()) - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - await attach.data.sign(my_info.verkey, wallet) - did = conn_rec.my_did + if not my_label: + my_label = self.profile.settings.get("default_label") + assert my_label did_url = None if conn_rec.their_public_did is not None: @@ -362,16 +326,37 @@ async def create_request( pthid = conn_rec.invitation_msg_id or did_url - if not my_label: - my_label = self.profile.settings.get("default_label") + if conn_rec.connection_protocol == DIDEX_1_0: + did, attach = await self._legacy_did_with_attached_doc( + conn_rec, my_endpoints, mediation_records + ) + else: + emit_did_peer_2 = bool(self.profile.settings.get("emit_did_peer_2")) + emit_did_peer_4 = bool(self.profile.settings.get("emit_did_peer_4")) + try: + did, attach = await self._qualified_did_with_fallback( + conn_rec, + my_endpoints, + mediation_records, + emit_did_peer_2, + emit_did_peer_4, + ) + except LegacyHandlingFallback: + did, attach = await self._legacy_did_with_attached_doc( + conn_rec, my_endpoints, mediation_records + ) request = DIDXRequest( label=my_label, did=did, did_doc_attach=attach, - goal_code=goal_code, goal=goal, + goal_code=goal_code, ) + + if conn_rec.connection_protocol == DIDEX_1_0: + request.assign_version("1.0") + request.assign_thread_id(thid=request._id, pthid=pthid) # Update connection state @@ -387,12 +372,109 @@ async def create_request( return request + async def _qualified_did_with_fallback( + self, + conn_rec: ConnRecord, + my_endpoints: Sequence[str], + mediation_records: List[MediationRecord], + emit_did_peer_2: bool, + emit_did_peer_4: bool, + signing_key: Optional[str] = None, + ) -> Tuple[str, Optional[AttachDecorator]]: + """Create DID Exchange request using a qualified DID. + + Fall back to unqualified DID if settings don't cause did:peer emission. + """ + if emit_did_peer_2 and emit_did_peer_4: + self._logger.warning( + "emit_did_peer_2 and emit_did_peer_4 both set, \ + using did:peer:4" + ) + + if conn_rec.my_did: # DID should be public or qualified + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_local_did(conn_rec.my_did) + + posture = DIDPosture.get(my_info.metadata) + if posture not in ( + DIDPosture.PUBLIC, + DIDPosture.POSTED, + ) and not my_info.did.startswith("did:"): + raise LegacyHandlingFallback( + "DID has been previously set and not public or qualified" + ) + elif emit_did_peer_4: + my_info = await self.create_did_peer_4(my_endpoints, mediation_records) + conn_rec.my_did = my_info.did + elif emit_did_peer_2: + my_info = await self.create_did_peer_2(my_endpoints, mediation_records) + conn_rec.my_did = my_info.did + else: + raise LegacyHandlingFallback( + "Use of qualified DIDs not set according to settings" + ) + + did = conn_rec.my_did + assert did, "DID must be set on connection record" + if not did.startswith("did:"): + did = f"did:sov:{did}" + + attach = None + if signing_key: + attach = AttachDecorator.data_base64_string(did) + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + await attach.data.sign(signing_key, wallet) + + return did, attach + + async def _legacy_did_with_attached_doc( + self, + conn_rec: ConnRecord, + my_endpoints: Sequence[str], + mediation_records: List[MediationRecord], + invitation_key: Optional[str] = None, + ) -> Tuple[str, Optional[AttachDecorator]]: + """Create a DID Exchange request using an unqualified DID.""" + if conn_rec.my_did: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.get_local_did(conn_rec.my_did) + else: + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + my_info = await wallet.create_local_did( + method=SOV, + key_type=ED25519, + ) + conn_rec.my_did = my_info.did + + posture = DIDPosture.get(my_info.metadata) + if posture in ( + DIDPosture.PUBLIC, + DIDPosture.POSTED, + ): + return my_info.did, None + + did_doc = await self.create_did_document( + my_info, + my_endpoints, + mediation_records=mediation_records, + ) + attach = AttachDecorator.data_base64(did_doc.serialize()) + + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + await attach.data.sign(invitation_key or my_info.verkey, wallet) + + return my_info.did, attach + async def receive_request( self, request: DIDXRequest, recipient_did: str, recipient_verkey: Optional[str] = None, - my_endpoint: Optional[str] = None, alias: Optional[str] = None, auto_accept_implicit: Optional[bool] = None, ) -> ConnRecord: @@ -415,121 +497,124 @@ async def receive_request( settings=self.profile.settings, ) - conn_rec = None - connection_key = None - my_info = None - - # Determine what key will need to sign the response - if recipient_verkey: # peer DID - connection_key = recipient_verkey - try: - async with self.profile.session() as session: - conn_rec = await ConnRecord.retrieve_by_invitation_key( - session=session, - invitation_key=connection_key, - their_role=ConnRecord.Role.REQUESTER.rfc23, - ) - except StorageNotFoundError: - if recipient_verkey: - raise DIDXManagerError( - "No explicit invitation found for pairwise connection " - f"in state {ConnRecord.State.INVITATION.rfc23}: " - "a prior connection request may have updated the connection state" - ) + if recipient_verkey: + conn_rec = await self._receive_request_pairwise_did( + request, recipient_verkey, alias + ) else: - if not self.profile.settings.get("public_invites"): - raise DIDXManagerError( - "Public invitations are not enabled: connection request refused" - ) + conn_rec = await self._receive_request_public_did( + request, recipient_did, alias, auto_accept_implicit + ) - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.get_local_did(recipient_did) - if DIDPosture.get(my_info.metadata) not in ( - DIDPosture.PUBLIC, - DIDPosture.POSTED, - ): - raise DIDXManagerError(f"Request DID {recipient_did} is not public") - connection_key = my_info.verkey + # Clean associated oob record if not needed anymore + oob_processor = self.profile.inject(OobMessageProcessor) + await oob_processor.clean_finished_oob_record(self.profile, request) + return conn_rec + + async def _receive_request_pairwise_did( + self, + request: DIDXRequest, + recipient_verkey: str, + alias: Optional[str] = None, + ) -> ConnRecord: + """Receive a DID Exchange request against a pairwise (not public) DID.""" + try: async with self.profile.session() as session: - conn_rec = await ConnRecord.retrieve_by_invitation_msg_id( + conn_rec = await ConnRecord.retrieve_by_invitation_key( session=session, - invitation_msg_id=request._thread.pthid, + invitation_key=recipient_verkey, their_role=ConnRecord.Role.REQUESTER.rfc23, ) + except StorageNotFoundError: + raise DIDXManagerError( + "No explicit invitation found for pairwise connection " + f"in state {ConnRecord.State.INVITATION.rfc23}: " + "a prior connection request may have updated the connection state" + ) - save_reason = None - if conn_rec: # invitation was explicit - connection_key = conn_rec.invitation_key - if conn_rec.is_multiuse_invitation: - new_conn_rec = ConnRecord( - invitation_key=connection_key, - state=ConnRecord.State.INIT.rfc160, - accept=conn_rec.accept, - their_role=conn_rec.their_role, - connection_protocol=DIDX_PROTO, - ) - async with self.profile.session() as session: - # TODO: Suppress the event that gets emitted here? - await new_conn_rec.save( - session, - reason="Created new connection record from multi-use invitation", - ) + if conn_rec.is_multiuse_invitation: + conn_rec = await self._derive_new_conn_from_multiuse_invitation(conn_rec) - # Transfer metadata from multi-use to new connection - # Must come after save so there's an ID to associate with metadata - async with self.profile.session() as session: - for key, value in ( - await conn_rec.metadata_get_all(session) - ).items(): - await new_conn_rec.metadata_set(session, key, value) + conn_rec.their_label = request.label + if alias: + conn_rec.alias = alias + conn_rec.their_did = request.did + conn_rec.state = ConnRecord.State.REQUEST.rfc160 + conn_rec.request_id = request._id + conn_rec.connection_protocol = self._handshake_protocol_to_use(request) - conn_rec = new_conn_rec + # TODO move to common method or add to transaction? + await self._extract_and_record_did_doc_info(request) - # request DID doc describes requester DID - if request.did_doc_attach and request.did_doc_attach.data: - self._logger.debug("Received DID Doc attachment in request") - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - conn_did_doc = await self.verify_diddoc(wallet, request.did_doc_attach) - await self.store_did_document(conn_did_doc) + async with self.profile.transaction() as txn: + # Attach the connection request so it can be found and responded to + await conn_rec.save( + txn, reason="Received connection request from invitation" + ) + await conn_rec.attach_request(txn, request) + await txn.commit() - # Special case: legacy DIDs were unqualified in request, qualified in doc - if request.did and not request.did.startswith("did:"): - did_to_check = f"did:sov:{request.did}" - else: - did_to_check = request.did + return conn_rec - if did_to_check != conn_did_doc["id"]: - raise DIDXManagerError( - ( - f"Connection DID {request.did} does not match " - f"DID Doc id {conn_did_doc['id']}" - ), - error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value, - ) - else: - if request.did is None: - raise DIDXManagerError("No DID in request") + def _handshake_protocol_to_use(self, request: DIDXRequest): + """Determine the connection protocol to use based on the request. - self._logger.debug( - "No DID Doc attachment in request; doc will be resolved from DID" + If we support it, we'll send it. If we don't, we'll try didexchage/1.1. + """ + protocol = f"{request._type.protocol}/{request._type.version}" + if protocol in ConnRecord.SUPPORTED_PROTOCOLS: + return protocol + + return DIDEX_1_1 + + async def _receive_request_public_did( + self, + request: DIDXRequest, + recipient_did: str, + alias: Optional[str] = None, + auto_accept_implicit: Optional[bool] = None, + ) -> ConnRecord: + """Receive a DID Exchange request against a public DID.""" + if not self.profile.settings.get("public_invites"): + raise DIDXManagerError( + "Public invitations are not enabled: connection request refused" ) - await self.record_keys_for_resolvable_did(request.did) - if conn_rec: # request is against explicit invitation - auto_accept = ( - conn_rec.accept == ConnRecord.ACCEPT_AUTO - ) # null=manual; oob-manager calculated at conn rec creation + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + public_did_info = await wallet.get_local_did(recipient_did) + + if DIDPosture.get(public_did_info.metadata) not in ( + DIDPosture.PUBLIC, + DIDPosture.POSTED, + ): + raise DIDXManagerError(f"Request DID {recipient_did} is not public") + if request._thread.pthid: + # Invitation was explicit + async with self.profile.session() as session: + conn_rec = await ConnRecord.retrieve_by_invitation_msg_id( + session=session, + invitation_msg_id=request._thread.pthid, + their_role=ConnRecord.Role.REQUESTER.rfc23, + ) + else: + # Invitation was implicit + conn_rec = None + + if conn_rec and conn_rec.is_multiuse_invitation: + conn_rec = await self._derive_new_conn_from_multiuse_invitation(conn_rec) + + save_reason = None + if conn_rec: conn_rec.their_label = request.label if alias: conn_rec.alias = alias conn_rec.their_did = request.did conn_rec.state = ConnRecord.State.REQUEST.rfc160 conn_rec.request_id = request._id - save_reason = "Received connection request from invitation" + save_reason = "Received connection request from invitation to public DID" else: # request is against implicit invitation on public DID if not self.profile.settings.get("requests_through_public_did"): @@ -554,26 +639,93 @@ async def receive_request( their_label=request.label, alias=alias, their_role=ConnRecord.Role.REQUESTER.rfc23, - invitation_key=connection_key, + invitation_key=public_did_info.verkey, invitation_msg_id=None, request_id=request._id, state=ConnRecord.State.REQUEST.rfc160, - connection_protocol=DIDX_PROTO, ) save_reason = "Received connection request from public DID" + conn_rec.connection_protocol = self._handshake_protocol_to_use(request) + + # TODO move to common method or add to transaction? + await self._extract_and_record_did_doc_info(request) + async with self.profile.transaction() as txn: # Attach the connection request so it can be found and responded to await conn_rec.save(txn, reason=save_reason) await conn_rec.attach_request(txn, request) await txn.commit() - # Clean associated oob record if not needed anymore - oob_processor = self.profile.inject(OobMessageProcessor) - await oob_processor.clean_finished_oob_record(self.profile, request) - return conn_rec + async def _extract_and_record_did_doc_info(self, request: DIDXRequest): + """Extract and record DID Document information from the DID Exchange request. + + Extracting this info enables us to correlate messages from these keys back to a + connection when we later receive inbound messages. + """ + if request.did_doc_attach and request.did_doc_attach.data: + self._logger.debug("Received DID Doc attachment in request") + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + conn_did_doc = await self.verify_diddoc(wallet, request.did_doc_attach) + await self.store_did_document(conn_did_doc) + + # Special case: legacy DIDs were unqualified in request, qualified in doc + if request.did and not request.did.startswith("did:"): + did_to_check = f"did:sov:{request.did}" + else: + did_to_check = request.did + + if did_to_check != conn_did_doc["id"]: + raise DIDXManagerError( + ( + f"Connection DID {request.did} does not match " + f"DID Doc id {conn_did_doc['id']}" + ), + error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value, + ) + else: + if request.did is None: + raise DIDXManagerError("No DID in request") + + self._logger.debug( + "No DID Doc attachment in request; doc will be resolved from DID" + ) + await self.record_keys_for_resolvable_did(request.did) + + async def _derive_new_conn_from_multiuse_invitation( + self, conn_rec: ConnRecord + ) -> ConnRecord: + """Derive a new connection record from a multi-use invitation. + + Multi-use invitations are tracked using a connection record. When a connection + is formed through a multi-use invitation conn rec, a new record for the resulting + connection is required. The original multi-use invitation record is retained + until deleted by the user. + """ + new_conn_rec = ConnRecord( + invitation_key=conn_rec.invitation_key, + state=ConnRecord.State.INIT.rfc160, + accept=conn_rec.accept, + their_role=conn_rec.their_role, + ) + async with self.profile.session() as session: + # TODO: Suppress the event that gets emitted here? + await new_conn_rec.save( + session, + reason="Created new connection record from multi-use invitation", + ) + + # Transfer metadata from multi-use to new connection + # Must come after save so there's an ID to associate with metadata + async with self.profile.session() as session: + for key, value in (await conn_rec.metadata_get_all(session)).items(): + await new_conn_rec.metadata_set(session, key, value) + + return new_conn_rec + async def create_response( self, conn_rec: ConnRecord, @@ -619,75 +771,55 @@ async def create_response( my_endpoints.append(default_endpoint) my_endpoints.extend(self.profile.settings.get("additional_endpoints", [])) - respond_with_did_peer_2 = self.profile.settings.get("emit_did_peer_2") or ( - conn_rec.their_did and conn_rec.their_did.startswith("did:peer:2") + respond_with_did_peer_2 = bool( + self.profile.settings.get("emit_did_peer_2") + or (conn_rec.their_did and conn_rec.their_did.startswith("did:peer:2")) ) - respond_with_did_peer_4 = self.profile.settings.get("emit_did_peer_4") or ( - conn_rec.their_did and conn_rec.their_did.startswith("did:peer:4") + respond_with_did_peer_4 = bool( + self.profile.settings.get("emit_did_peer_4") + or (conn_rec.their_did and conn_rec.their_did.startswith("did:peer:4")) ) - if conn_rec.my_did: - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.get_local_did(conn_rec.my_did) - did = my_info.did - elif respond_with_did_peer_4: - my_info = await self.create_did_peer_4(my_endpoints, mediation_records) - conn_rec.my_did = my_info.did - did = my_info.did - elif respond_with_did_peer_2: - my_info = await self.create_did_peer_2(my_endpoints, mediation_records) - conn_rec.my_did = my_info.did - did = my_info.did - elif use_public_did: + if use_public_did: async with self.profile.session() as session: wallet = session.inject(BaseWallet) - my_info = await wallet.get_public_did() - if not my_info: + public_info = await wallet.get_public_did() + if public_info: + conn_rec.my_did = public_info.did + else: raise DIDXManagerError("No public DID configured") - conn_rec.my_did = my_info.did - did = my_info.did - if not did.startswith("did:"): - did = f"did:sov:{did}" + if conn_rec.connection_protocol == DIDEX_1_0: + did, attach = await self._legacy_did_with_attached_doc( + conn_rec, + my_endpoints, + mediation_records, + invitation_key=conn_rec.invitation_key, + ) + response = DIDXResponse(did=did, did_doc_attach=attach) + response.assign_version("1.0") else: - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - my_info = await wallet.create_local_did( - method=SOV, - key_type=ED25519, + try: + did, attach = await self._qualified_did_with_fallback( + conn_rec, + my_endpoints, + mediation_records, + respond_with_did_peer_2, + respond_with_did_peer_4, + signing_key=conn_rec.invitation_key, ) - conn_rec.my_did = my_info.did - did = my_info.did + response = DIDXResponse(did=did, did_rotate_attach=attach) + except LegacyHandlingFallback: + did, attach = await self._legacy_did_with_attached_doc( + conn_rec, my_endpoints, mediation_records, conn_rec.invitation_key + ) + response = DIDXResponse(did=did, did_doc_attach=attach) # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( self.profile, conn_rec, mediation_records ) - if did.startswith("did:"): # It's a "real" resolvable did - # Omit DID Doc attachment if we're using a public DID or peer did - attach = AttachDecorator.data_base64_string(did) - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - if conn_rec.invitation_key is not None: - await attach.data.sign(conn_rec.invitation_key, wallet) - else: - self._logger.warning("Invitation key was not set for connection") - attach = None - response = DIDXResponse(did=did, did_rotate_attach=attach) - else: - did_doc = await self.create_did_document( - my_info, - my_endpoints, - mediation_records=mediation_records, - ) - attach = AttachDecorator.data_base64(did_doc.serialize()) - async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - await attach.data.sign(conn_rec.invitation_key, wallet) - response = DIDXResponse(did=did, did_doc_attach=attach) - # Assign thread information response.assign_thread_from(request) response.assign_trace_from(request) @@ -707,7 +839,7 @@ async def create_response( ) if send_mediation_request: temp_mediation_mgr = MediationManager(self.profile) - _record, request = await temp_mediation_mgr.prepare_request( + _, request = await temp_mediation_mgr.prepare_request( conn_rec.connection_id ) responder = self.profile.inject(BaseResponder) @@ -859,6 +991,9 @@ async def accept_response( # create and send connection-complete message complete = DIDXComplete() complete.assign_thread_from(response) + if conn_rec.connection_protocol == DIDEX_1_0: + complete.assign_version("1.0") + responder = self.profile.inject_or(BaseResponder) if responder: await responder.send_reply(complete, connection_id=conn_rec.connection_id) @@ -969,6 +1104,8 @@ async def reject( "en": reason or "DID exchange rejected", }, ) + if conn_rec.connection_protocol == DIDEX_1_0: + report.assign_version("1.0") # TODO Delete the record? return report @@ -1043,6 +1180,7 @@ async def manager_error_to_problem_report( description={"en": e.message, "code": e.error_code} ) report.assign_thread_from(message) + report.assign_version_from(message) if message.did_doc_attach: try: # convert diddoc attachment to diddoc... diff --git a/aries_cloudagent/protocols/didexchange/v1_0/message_types.py b/aries_cloudagent/protocols/didexchange/v1_0/message_types.py index ecb038e1f4..a1b2e3873b 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/message_types.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/message_types.py @@ -1,12 +1,18 @@ """Message type identifiers for Connections.""" +from ....messaging.util import get_proto_default_version from ...didcomm_prefix import DIDCommPrefix +from ..definition import versions SPEC_URI = ( "https://github.com/hyperledger/aries-rfcs/tree/" "25464a5c8f8a17b14edaa4310393df6094ace7b0/features/0023-did-exchange" ) -ARIES_PROTOCOL = "didexchange/1.0" +# Default Version +DEFAULT_VERSION = get_proto_default_version(versions, 1) +DIDEX_1_0 = "didexchange/1.0" +DIDEX_1_1 = "didexchange/1.1" +ARIES_PROTOCOL = f"didexchange/{DEFAULT_VERSION}" # Message types DIDX_REQUEST = f"{ARIES_PROTOCOL}/request" diff --git a/aries_cloudagent/protocols/didexchange/v1_0/routes.py b/aries_cloudagent/protocols/didexchange/v1_0/routes.py index 4c3197bd57..8f81bd715f 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/routes.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/routes.py @@ -11,7 +11,7 @@ response_schema, ) -from marshmallow import fields +from marshmallow import fields, validate from ....admin.request_context import AdminRequestContext from ....connections.models.conn_record import ConnRecord, ConnRecordSchema @@ -28,7 +28,7 @@ from ....storage.error import StorageError, StorageNotFoundError from ....wallet.error import WalletError from .manager import DIDXManager, DIDXManagerError -from .message_types import SPEC_URI +from .message_types import DIDEX_1_0, DIDEX_1_1, SPEC_URI from .messages.request import DIDXRequest, DIDXRequestSchema @@ -106,6 +106,14 @@ class DIDXCreateRequestImplicitQueryStringSchema(OpenAPISchema): "example": "To issue a Faber College Graduate credential", }, ) + protocol = fields.Str( + required=False, + validate=validate.OneOf([DIDEX_1_0, DIDEX_1_1]), + metadata={ + "description": "Which DID Exchange Protocol version to use", + "example": "didexchange/1.0", + }, + ) class DIDXReceiveRequestImplicitQueryStringSchema(OpenAPISchema): @@ -223,7 +231,7 @@ async def didx_accept_invitation(request: web.BaseRequest): try: async with profile.session() as session: conn_rec = await ConnRecord.retrieve_by_id(session, connection_id) - request = await didx_mgr.create_request( + didx_request = await didx_mgr.create_request( conn_rec=conn_rec, my_label=my_label, my_endpoint=my_endpoint, @@ -235,7 +243,7 @@ async def didx_accept_invitation(request: web.BaseRequest): except (StorageError, WalletError, DIDXManagerError, BaseModelError) as err: raise web.HTTPBadRequest(reason=err.roll_up) from err - await outbound_handler(request, connection_id=conn_rec.connection_id) + await outbound_handler(didx_request, connection_id=conn_rec.connection_id) return web.json_response(result) @@ -258,7 +266,7 @@ async def didx_create_request_implicit(request: web.BaseRequest): """ context: AdminRequestContext = request["context"] - their_public_did = request.query.get("their_public_did") + their_public_did = request.query["their_public_did"] my_label = request.query.get("my_label") or None my_endpoint = request.query.get("my_endpoint") or None mediation_id = request.query.get("mediation_id") or None @@ -267,11 +275,12 @@ async def didx_create_request_implicit(request: web.BaseRequest): goal_code = request.query.get("goal_code") or None goal = request.query.get("goal") or None auto_accept = json.loads(request.query.get("auto_accept", "null")) + protocol = request.query.get("protocol") or None profile = context.profile didx_mgr = DIDXManager(profile) try: - request = await didx_mgr.create_request_implicit( + didx_request = await didx_mgr.create_request_implicit( their_public_did=their_public_did, my_label=my_label, my_endpoint=my_endpoint, @@ -281,18 +290,20 @@ async def didx_create_request_implicit(request: web.BaseRequest): goal_code=goal_code, goal=goal, auto_accept=auto_accept, + protocol=protocol, ) except StorageNotFoundError as err: raise web.HTTPNotFound(reason=err.roll_up) from err except (StorageError, WalletError, DIDXManagerError, BaseModelError) as err: raise web.HTTPBadRequest(reason=err.roll_up) from err - return web.json_response(request.serialize()) + return web.json_response(didx_request.serialize()) @docs( tags=["did-exchange"], summary="Receive request against public DID's implicit invitation", + deprecated=True, ) @querystring_schema(DIDXReceiveRequestImplicitQueryStringSchema()) @request_schema(DIDXRequestSchema()) @@ -311,21 +322,17 @@ async def didx_receive_request_implicit(request: web.BaseRequest): body = await request.json() alias = request.query.get("alias") - my_endpoint = request.query.get("my_endpoint") auto_accept = json.loads(request.query.get("auto_accept", "null")) - mediation_id = request.query.get("mediation_id") or None profile = context.profile didx_mgr = DIDXManager(profile) try: - request = DIDXRequest.deserialize(body) + didx_request = DIDXRequest.deserialize(body) conn_rec = await didx_mgr.receive_request( - request=request, - recipient_did=request._thread.pthid.split(":")[-1], + request=didx_request, + recipient_did=didx_request._thread.pthid.split(":")[-1], alias=alias, - my_endpoint=my_endpoint, auto_accept_implicit=auto_accept, - mediation_id=mediation_id, ) result = conn_rec.serialize() except StorageNotFoundError as err: diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index fbfdcde469..feb7020ce2 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -532,24 +532,6 @@ async def test_create_request_my_endpoint(self): ) assert didx_req - async def test_create_request_public_did(self): - mock_conn_rec = mock.MagicMock( - connection_id="dummy", - my_did=self.did_info.did, - their_did=TestConfig.test_target_did, - their_role=ConnRecord.Role.REQUESTER.rfc23, - state=ConnRecord.State.REQUEST.rfc23, - retrieve_invitation=mock.CoroutineMock( - return_value=mock.MagicMock( - services=[TestConfig.test_target_did], - ) - ), - save=mock.CoroutineMock(), - ) - - request = await self.manager.create_request(mock_conn_rec, use_public_did=True) - assert request.did_doc_attach is None - async def test_create_request_emit_did_peer_2(self): mock_conn_rec = mock.MagicMock( connection_id="dummy", @@ -577,7 +559,7 @@ async def test_create_request_emit_did_peer_2(self): mock.AsyncMock(return_value=mock_did_info), ) as mock_create_did_peer_2: request = await self.manager.create_request( - mock_conn_rec, use_public_did=True + mock_conn_rec, ) assert request.did_doc_attach is None mock_create_did_peer_2.assert_called_once() @@ -609,7 +591,7 @@ async def test_create_request_emit_did_peer_4(self): mock.AsyncMock(return_value=mock_did_info), ) as mock_create_did_peer_4: request = await self.manager.create_request( - mock_conn_rec, use_public_did=True + mock_conn_rec, ) assert request.did_doc_attach is None mock_create_did_peer_4.assert_called_once() @@ -716,7 +698,6 @@ async def test_receive_request_explicit_public_did(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=None, alias=None, auto_accept_implicit=None, ) @@ -751,7 +732,6 @@ async def test_receive_request_invi_not_found(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=TestConfig.test_verkey, - my_endpoint=None, alias=None, auto_accept_implicit=None, ) @@ -853,7 +833,6 @@ async def test_receive_request_public_did_no_did_doc_attachment(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=None, alias=None, auto_accept_implicit=None, ) @@ -916,7 +895,6 @@ async def test_receive_request_public_did_no_did_doc_attachment_no_did(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=None, alias=None, auto_accept_implicit=None, ) @@ -958,7 +936,6 @@ async def test_receive_request_public_did_x_not_public(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=TestConfig.test_endpoint, alias="Alias", auto_accept_implicit=False, ) @@ -1022,7 +999,6 @@ async def test_receive_request_public_did_x_wrong_did(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=TestConfig.test_endpoint, alias="Alias", auto_accept_implicit=False, ) @@ -1084,7 +1060,6 @@ async def test_receive_request_public_did_x_did_doc_attach_bad_sig(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=TestConfig.test_endpoint, alias="Alias", auto_accept_implicit=False, ) @@ -1126,7 +1101,6 @@ async def test_receive_request_public_did_no_public_invites(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=TestConfig.test_endpoint, alias="Alias", auto_accept_implicit=False, ) @@ -1198,7 +1172,6 @@ async def test_receive_request_public_did_no_auto_accept(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=TestConfig.test_endpoint, alias="Alias", auto_accept_implicit=False, ) @@ -1266,7 +1239,6 @@ async def test_receive_request_implicit_public_did_not_enabled(self): await self.manager.receive_request( request=mock_request, recipient_did=TestConfig.test_did, - my_endpoint=None, alias=None, auto_accept_implicit=None, ) @@ -1347,7 +1319,6 @@ async def test_receive_request_implicit_public_did(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=None, - my_endpoint=None, alias=None, auto_accept_implicit=None, ) @@ -1434,7 +1405,6 @@ async def test_receive_request_peer_did(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=TestConfig.test_verkey, - my_endpoint=TestConfig.test_endpoint, alias="Alias", auto_accept_implicit=False, ) @@ -1476,7 +1446,6 @@ async def test_receive_request_peer_did_not_found_x(self): request=mock_request, recipient_did=TestConfig.test_did, recipient_verkey=TestConfig.test_verkey, - my_endpoint=TestConfig.test_endpoint, alias="Alias", auto_accept_implicit=False, ) diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_disclose_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_disclose_handler.py index 10cc70d019..7b506e45a6 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_disclose_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_disclose_handler.py @@ -14,12 +14,12 @@ from ...messages.query import Query from ...models.discovery_record import V10DiscoveryExchangeRecord -TEST_MESSAGE_FAMILY = "TEST_FAMILY" -TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/MESSAGE" +TEST_MESSAGE_FAMILY = "doc/proto/1.0" +TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/message" @pytest.fixture() -def request_context() -> RequestContext: +def request_context(): ctx = RequestContext.test_context() ctx.connection_ready = True ctx.connection_record = mock.MagicMock(connection_id="test123") @@ -28,7 +28,7 @@ def request_context() -> RequestContext: class TestDiscloseHandler: @pytest.mark.asyncio - async def test_disclose(self, request_context): + async def test_disclose(self, request_context: RequestContext): registry = ProtocolRegistry() registry.register_message_types({TEST_MESSAGE_TYPE: object()}) request_context.injector.bind_instance(ProtocolRegistry, registry) diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py index 10ab3af132..16d5b7345a 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py @@ -10,12 +10,12 @@ from ...messages.disclose import Disclose from ...messages.query import Query -TEST_MESSAGE_FAMILY = "TEST_FAMILY" -TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/MESSAGE" +TEST_MESSAGE_FAMILY = "doc/proto/1.0" +TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/message" @pytest.fixture() -def request_context() -> RequestContext: +def request_context(): ctx = RequestContext.test_context() registry = ProtocolRegistry() registry.register_message_types({TEST_MESSAGE_TYPE: object()}) @@ -44,7 +44,7 @@ async def test_query_all(self, request_context): async def test_query_all_disclose_list_settings(self, request_context): profile = request_context.profile registry = profile.inject(ProtocolRegistry) - registry.register_message_types({"TEST_FAMILY_B/MESSAGE": object()}) + registry.register_message_types({"doc/proto-b/1.0/message": object()}) profile.context.injector.bind_instance(ProtocolRegistry, registry) profile.settings["disclose_protocol_list"] = [TEST_MESSAGE_FAMILY] query_msg = Query(query="*") @@ -75,7 +75,7 @@ async def test_receive_query_process_disclosed(self, request_context): mock_prepare_disclosed.return_value = [ {"test": "test"}, { - "pid": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/action-menu/1.0", + "pid": "https://didcomm.org/action-menu/1.0", "roles": ["provider"], }, ] @@ -85,9 +85,6 @@ async def test_receive_query_process_disclosed(self, request_context): result, target = messages[0] assert isinstance(result, Disclose) and result.protocols assert len(result.protocols) == 1 - assert ( - result.protocols[0]["pid"] - == "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/action-menu/1.0" - ) + assert result.protocols[0]["pid"] == "https://didcomm.org/action-menu/1.0" assert result.protocols[0]["roles"] == ["provider"] assert not target diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_disclosures_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_disclosures_handler.py index 6e190148b8..cd592aff3a 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_disclosures_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_disclosures_handler.py @@ -16,12 +16,12 @@ from ...messages.queries import Queries, QueryItem from ...models.discovery_record import V20DiscoveryExchangeRecord -TEST_MESSAGE_FAMILY = "TEST_FAMILY" -TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/MESSAGE" +TEST_MESSAGE_FAMILY = "doc/proto/1.0" +TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/message" @pytest.fixture() -def request_context() -> RequestContext: +def request_context(): ctx = RequestContext.test_context() ctx.connection_ready = True ctx.connection_record = mock.MagicMock(connection_id="test123") diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py index 1998a973e0..9560fbd0d8 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py @@ -22,8 +22,8 @@ from ...messages.disclosures import Disclosures from ...messages.queries import Queries, QueryItem -TEST_MESSAGE_FAMILY = "TEST_FAMILY" -TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/MESSAGE" +TEST_MESSAGE_FAMILY = "doc/proto/1.0" +TEST_MESSAGE_TYPE = TEST_MESSAGE_FAMILY + "/message" @pytest.fixture() @@ -88,7 +88,7 @@ async def test_queries_protocol_goal_code_all_disclose_list_settings( ): profile = request_context.profile protocol_registry = profile.inject(ProtocolRegistry) - protocol_registry.register_message_types({"TEST_FAMILY_B/MESSAGE": object()}) + protocol_registry.register_message_types({"doc/proto-b/1.0/message": object()}) profile.context.injector.bind_instance(ProtocolRegistry, protocol_registry) goal_code_registry = profile.inject(GoalCodeRegistry) goal_code_registry.register_controllers(pres_proof_v1_controller) @@ -139,7 +139,7 @@ async def test_receive_query_process_disclosed(self, request_context): mock_exec_protocol_query.return_value = [ {"test": "test"}, { - "pid": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/action-menu/1.0", + "pid": "https://didcomm.org/action-menu/1.0", "roles": ["provider"], }, ] @@ -150,8 +150,7 @@ async def test_receive_query_process_disclosed(self, request_context): result, target = messages[0] assert isinstance(result, Disclosures) assert ( - result.disclosures[0].get("id") - == "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/action-menu/1.0" + result.disclosures[0].get("id") == "https://didcomm.org/action-menu/1.0" ) assert result.disclosures[0].get("feature-type") == "protocol" assert result.disclosures[1].get("id") == "aries.vc" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index e334e8d8fa..382f7aecfb 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -9,7 +9,6 @@ from ....messaging.decorators.service_decorator import ServiceDecorator from ....core.event_bus import EventBus -from ....core.util import get_version_from_message from ....connections.base_manager import BaseConnectionManager from ....connections.models.conn_record import ConnRecord from ....core.error import BaseError @@ -205,6 +204,7 @@ async def create_invitation( handshake_protocols = [ DIDCommPrefix.qualify_current(hsp.name) for hsp in hs_protos or [] ] or None + # Handshake protocol list should be ordered by preference by caller connection_protocol = ( hs_protos[0].name if hs_protos and len(hs_protos) >= 1 else None ) @@ -358,7 +358,6 @@ async def create_invitation( ), alias=alias, connection_protocol=connection_protocol, - my_did=my_did, ) async with self.profile.session() as session: @@ -591,7 +590,7 @@ async def receive_invitation( # Try to reuse the connection. If not accepted sets the conn_rec to None if conn_rec and not invitation.requests_attach: oob_record = await self._handle_hanshake_reuse( - oob_record, conn_rec, get_version_from_message(invitation) + oob_record, conn_rec, invitation._version ) LOGGER.warning( @@ -895,13 +894,8 @@ async def _perform_handshake( invitation = oob_record.invitation supported_handshake_protocols = [ - HSProto.get(hsp) - for hsp in dict.fromkeys( - [ - DIDCommPrefix.unqualify(proto) - for proto in invitation.handshake_protocols - ] - ) + HSProto.get(DIDCommPrefix.unqualify(proto)) + for proto in invitation.handshake_protocols ] # Get the single service item @@ -947,7 +941,7 @@ async def _perform_handshake( conn_record = None for protocol in supported_handshake_protocols: # DIDExchange - if protocol is HSProto.RFC23: + if protocol is HSProto.RFC23 or protocol is HSProto.DIDEX_1_1: didx_mgr = DIDXManager(self.profile) conn_record = await didx_mgr.receive_invitation( invitation=invitation, @@ -955,6 +949,7 @@ async def _perform_handshake( auto_accept=auto_accept, alias=alias, mediation_id=mediation_id, + protocol=protocol.name, ) break # 0160 Connection @@ -1108,9 +1103,7 @@ async def receive_reuse_message( invi_msg_id = reuse_msg._thread.pthid reuse_msg_id = reuse_msg._thread_id - reuse_accept_msg = HandshakeReuseAccept( - version=get_version_from_message(reuse_msg) - ) + reuse_accept_msg = HandshakeReuseAccept(version=reuse_msg._version) reuse_accept_msg.assign_thread_id(thid=reuse_msg_id, pthid=invi_msg_id) connection_targets = await self.fetch_connection_targets(connection=conn_rec) diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py b/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py index e8fcd09a94..6ca6555b74 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py @@ -1,18 +1,18 @@ """Message and inner object type identifiers for Out of Band messages.""" -from ....core.util import get_proto_default_version +from ....messaging.util import get_proto_default_version from ...didcomm_prefix import DIDCommPrefix +from ..definition import versions + SPEC_URI = ( "https://github.com/hyperledger/aries-rfcs/tree/" "2da7fc4ee043effa3a9960150e7ba8c9a4628b68/features/0434-outofband" ) # Default Version -DEFAULT_VERSION = get_proto_default_version( - "aries_cloudagent.protocols.out_of_band.definition", 1 -) +DEFAULT_VERSION = get_proto_default_version(versions, 1) # Message types INVITATION = f"out-of-band/{DEFAULT_VERSION}/invitation" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py index 09a2dfe19e..53a4e04b75 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py @@ -1,9 +1,7 @@ """An invitation content message.""" -from collections import namedtuple from enum import Enum -from re import sub -from typing import Optional, Sequence, Text, Union +from typing import NamedTuple, Optional, Sequence, Set, Text, Union from urllib.parse import parse_qs, urljoin, urlparse from marshmallow import EXCLUDE, ValidationError, fields, post_dump, validates_schema @@ -17,61 +15,72 @@ from .....wallet.util import b64_to_bytes, bytes_to_b64 from ....connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO from ....didcomm_prefix import DIDCommPrefix -from ....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO +from ....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDEX_1_1, DIDEX_1_0 from ..message_types import DEFAULT_VERSION, INVITATION from .service import Service -HSProtoSpec = namedtuple("HSProtoSpec", "rfc name aka") + +class HSProtoSpec(NamedTuple): + """Handshake protocol specification.""" + + name: str + aka: Set[str] class HSProto(Enum): """Handshake protocol enum for invitation message.""" RFC160 = HSProtoSpec( - 160, CONN_PROTO, {"connection", "connections", "conn", "conns", "rfc160", "160", "old"}, ) RFC23 = HSProtoSpec( - 23, - DIDX_PROTO, - {"didexchange", "didx", "didex", "rfc23", "23", "new"}, + DIDEX_1_0, + { + "https://didcomm.org/didexchange/1.0", + "didexchange/1.0", + "didexchange", + "did-exchange", + "didx", + "didex", + "rfc23", + "rfc-23", + "23", + "new", + }, + ) + DIDEX_1_1 = HSProtoSpec( + DIDEX_1_1, + { + "https://didcomm.org/didexchange/1.1", + "didexchange/1.1", + }, ) @classmethod - def get(cls, label: Union[str, "HSProto"]) -> "HSProto": + def get(cls, label: Union[str, "HSProto"]) -> Optional["HSProto"]: """Get handshake protocol enum for label.""" if isinstance(label, str): for hsp in HSProto: if ( DIDCommPrefix.unqualify(label) == hsp.name - or sub("[^a-zA-Z0-9]+", "", label.lower()) in hsp.aka + or label.lower() in hsp.aka ): return hsp elif isinstance(label, HSProto): return label - elif isinstance(label, int): - for hsp in HSProto: - if hsp.rfc == label: - return hsp - return None - @property - def rfc(self) -> int: - """Accessor for RFC.""" - return self.value.rfc - @property def name(self) -> str: """Accessor for name.""" return self.value.name @property - def aka(self) -> int: + def aka(self) -> Set[str]: """Accessor for also-known-as.""" return self.value.aka diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py index b37cbd102b..1db6584ab0 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py @@ -8,7 +8,7 @@ from .....connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO from .....didcomm_prefix import DIDCommPrefix -from .....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO +from .....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDEX_1_1, DIDEX_1_0 from .....didexchange.v1_0.messages.request import DIDXRequest @@ -26,19 +26,16 @@ class TestHSProto(TestCase): def test_get(self): assert HSProto.get(HSProto.RFC160) is HSProto.RFC160 - assert HSProto.get(23) is HSProto.RFC23 assert HSProto.get("Old") is HSProto.RFC160 assert HSProto.get(DIDCommPrefix.qualify_current(CONN_PROTO)) is HSProto.RFC160 - assert HSProto.get(DIDX_PROTO) is HSProto.RFC23 + assert HSProto.get(DIDEX_1_0) is HSProto.RFC23 assert HSProto.get("did-exchange") is HSProto.RFC23 assert HSProto.get("RFC-23") is HSProto.RFC23 + assert HSProto.get(DIDEX_1_1) is HSProto.DIDEX_1_1 + assert HSProto.get("didexchange/1.1") is HSProto.DIDEX_1_1 assert HSProto.get("no such protocol") is None assert HSProto.get(None) is None - def test_properties(self): - assert HSProto.RFC160.rfc == 160 - assert HSProto.RFC23.name == DIDX_PROTO - class TestInvitationMessage(TestCase): def test_init(self): @@ -46,7 +43,7 @@ def test_init(self): invi_msg = InvitationMessage( comment="Hello", label="A label", - handshake_protocols=[DIDCommPrefix.qualify_current(DIDX_PROTO)], + handshake_protocols=[DIDCommPrefix.qualify_current(DIDEX_1_1)], services=[TEST_DID], ) assert invi_msg.services == [TEST_DID] @@ -56,7 +53,7 @@ def test_init(self): invi_msg = InvitationMessage( comment="Hello", label="A label", - handshake_protocols=[DIDCommPrefix.qualify_current(DIDX_PROTO)], + handshake_protocols=[DIDCommPrefix.qualify_current(DIDEX_1_1)], services=[service], version="1.0", ) @@ -117,7 +114,7 @@ def test_url_round_trip(self): invi_msg = InvitationMessage( comment="Hello", label="A label", - handshake_protocols=[DIDCommPrefix.qualify_current(DIDX_PROTO)], + handshake_protocols=[DIDCommPrefix.qualify_current(DIDEX_1_1)], services=[service], ) @@ -151,7 +148,7 @@ def test_assign_msg_type_version_to_model_inst(self): assert "1.2" in test_msg._type assert "1.1" in InvitationMessage.Meta.message_type test_req = DIDXRequest() - assert "1.0" in test_req._type + assert "1.1" in test_req._type assert "1.2" in test_msg._type assert "1.1" in InvitationMessage.Meta.message_type diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/models/tests/test_invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/models/tests/test_invitation.py index aafb44f9c0..50a7fade08 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/models/tests/test_invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/models/tests/test_invitation.py @@ -2,7 +2,7 @@ from .....didcomm_prefix import DIDCommPrefix -from .....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO +from .....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDEX_1_1 from ...messages.invitation import InvitationMessage @@ -34,7 +34,7 @@ def test_make_record(self): invi = InvitationMessage( comment="Hello", label="A label", - handshake_protocols=[DIDCommPrefix.qualify_current(DIDX_PROTO)], + handshake_protocols=[DIDCommPrefix.qualify_current(DIDEX_1_1)], services=[TEST_DID], ) data = { diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index fad679bbd8..05034e9685 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -15,7 +15,6 @@ from .....connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service from .....core.event_bus import EventBus from .....core.in_memory import InMemoryProfile -from .....core.util import get_version_from_message from .....core.oob_processor import OobMessageProcessor from .....did.did_key import DIDKey from .....messaging.decorators.attach_decorator import AttachDecorator @@ -935,7 +934,7 @@ async def test_create_handshake_reuse_msg(self): ) oob_record = await self.manager._create_handshake_reuse_message( - oob_record, self.test_conn_rec, get_version_from_message(invitation) + oob_record, self.test_conn_rec, invitation._version ) _, kwargs = self.responder.send.call_args @@ -1248,6 +1247,7 @@ async def test_receive_invitation_with_valid_mediation(self): auto_accept=None, alias=None, mediation_id=mediation_record._id, + protocol="didexchange/1.0", ) async def test_receive_invitation_with_invalid_mediation(self): @@ -1283,6 +1283,7 @@ async def test_receive_invitation_with_invalid_mediation(self): auto_accept=None, alias=None, mediation_id=None, + protocol="didexchange/1.0", ) async def test_receive_invitation_didx_services_with_service_block(self): @@ -1486,7 +1487,7 @@ async def test_receive_invitation_handshake_reuse(self): perform_handshake.assert_not_called() handle_handshake_reuse.assert_called_once_with( - ANY, test_exist_conn, get_version_from_message(oob_invitation) + ANY, test_exist_conn, oob_invitation._version ) assert result.state == OobRecord.STATE_ACCEPTED @@ -1547,7 +1548,7 @@ async def test_receive_invitation_handshake_reuse_failed(self): ) handle_handshake_reuse.assert_called_once_with( - ANY, test_exist_conn, get_version_from_message(oob_invitation) + ANY, test_exist_conn, oob_invitation._version ) perform_handshake.assert_called_once_with( oob_record=ANY, diff --git a/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes.py b/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes.py index cbce0e26e9..90ccebce43 100644 --- a/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes.py @@ -414,7 +414,7 @@ async def test_present_proof_credentials_list_dif(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -496,7 +496,7 @@ async def test_present_proof_credentials_list_dif_one_of_filter(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -574,7 +574,7 @@ async def test_present_proof_credentials_dif_no_tag_query(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -652,7 +652,7 @@ async def test_present_proof_credentials_single_ldp_vp_claim_format(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -730,7 +730,7 @@ async def test_present_proof_credentials_double_ldp_vp_claim_format(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -786,7 +786,7 @@ async def test_present_proof_credentials_single_ldp_vp_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -849,7 +849,7 @@ async def test_present_proof_credentials_double_ldp_vp_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -909,7 +909,7 @@ async def test_present_proof_credentials_list_limit_disclosure_no_bbs(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -972,7 +972,7 @@ async def test_present_proof_credentials_no_ldp_vp(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -1033,7 +1033,7 @@ async def test_present_proof_credentials_list_schema_uri(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -1125,7 +1125,7 @@ async def test_present_proof_credentials_list_dif_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -1826,7 +1826,7 @@ async def test_present_proof_send_presentation_dif_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ diff --git a/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes_anoncreds.py b/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes_anoncreds.py index 982463ee04..4740a46f3e 100644 --- a/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes_anoncreds.py +++ b/aries_cloudagent/protocols/present_proof/v2_0/tests/test_routes_anoncreds.py @@ -419,7 +419,7 @@ async def test_present_proof_credentials_list_dif(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -501,7 +501,7 @@ async def test_present_proof_credentials_list_dif_one_of_filter(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -579,7 +579,7 @@ async def test_present_proof_credentials_dif_no_tag_query(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -657,7 +657,7 @@ async def test_present_proof_credentials_single_ldp_vp_claim_format(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -735,7 +735,7 @@ async def test_present_proof_credentials_double_ldp_vp_claim_format(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -791,7 +791,7 @@ async def test_present_proof_credentials_single_ldp_vp_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -854,7 +854,7 @@ async def test_present_proof_credentials_double_ldp_vp_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -914,7 +914,7 @@ async def test_present_proof_credentials_list_limit_disclosure_no_bbs(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -977,7 +977,7 @@ async def test_present_proof_credentials_no_ldp_vp(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -1038,7 +1038,7 @@ async def test_present_proof_credentials_list_schema_uri(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -1130,7 +1130,7 @@ async def test_present_proof_credentials_list_dif_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ @@ -1831,7 +1831,7 @@ async def test_present_proof_send_presentation_dif_error(self): role="prover", pres_proposal=None, pres_request={ - "@type": "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/present-proof/2.0/request-presentation", + "@type": "https://didcomm.org/present-proof/2.0/request-presentation", "@id": "6ae00c6c-87fa-495a-b546-5f5953817c92", "comment": "string", "formats": [ diff --git a/aries_cloudagent/protocols/tests/test_didcomm_prefix.py b/aries_cloudagent/protocols/tests/test_didcomm_prefix.py index daa403a9c9..ec741708b5 100644 --- a/aries_cloudagent/protocols/tests/test_didcomm_prefix.py +++ b/aries_cloudagent/protocols/tests/test_didcomm_prefix.py @@ -1,5 +1,3 @@ -from os import environ - from unittest import IsolatedAsyncioTestCase from ..didcomm_prefix import DIDCommPrefix @@ -7,23 +5,10 @@ class TestDIDCommPrefix(IsolatedAsyncioTestCase): def test_didcomm_prefix(self): - DIDCommPrefix.set({}) - assert environ.get("DIDCOMM_PREFIX") == DIDCommPrefix.OLD.value - - DIDCommPrefix.set({"emit_new_didcomm_prefix": True}) - assert environ.get("DIDCOMM_PREFIX") == DIDCommPrefix.NEW.value assert DIDCommPrefix.qualify_current("hello") == ( f"{DIDCommPrefix.NEW.value}/hello" ) - # No longer possible to have the arg `False` but leaving in test - # Still want to be able to receive the OLD format, just not emit it - DIDCommPrefix.set({"emit_new_didcomm_prefix": False}) - assert environ.get("DIDCOMM_PREFIX") == DIDCommPrefix.OLD.value - assert DIDCommPrefix.qualify_current("hello") == ( - f"{DIDCommPrefix.OLD.value}/hello" - ) - old_q_hello = DIDCommPrefix.OLD.qualify("hello") new_q_hello = DIDCommPrefix.NEW.qualify("hello") assert old_q_hello == f"{DIDCommPrefix.OLD.value}/hello" diff --git a/scripts/run_tests b/scripts/run_tests index 898d2361dc..8cea7dc3a9 100755 --- a/scripts/run_tests +++ b/scripts/run_tests @@ -42,4 +42,4 @@ fi $CONTAINER_RUNTIME run --rm -ti --name aries-cloudagent-runner \ --platform linux/amd64 \ -v "$(pwd)/../test-reports:/usr/src/app/test-reports:z" \ - $DOCKER_ARGS aries-cloudagent-test "$@" + $DOCKER_ARGS localhost/aries-cloudagent-test "$@"