Skip to content

Commit

Permalink
fix: support all minor versions handshake (#1711)
Browse files Browse the repository at this point in the history
Signed-off-by: Timo Glastra <timo@animo.id>
  • Loading branch information
TimoGlastra committed Jan 30, 2024
1 parent c7886cb commit 40063e0
Show file tree
Hide file tree
Showing 14 changed files with 376 additions and 100 deletions.
24 changes: 18 additions & 6 deletions packages/core/src/agent/MessageHandlerRegistry.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import type { AgentMessage } from './AgentMessage'
import type { MessageHandler } from './MessageHandler'
import type { ParsedDidCommProtocolUri } from '../utils/messageType'

import { injectable } from 'tsyringe'

import { canHandleMessageType, parseMessageType } from '../utils/messageType'
import { supportsIncomingDidCommProtocolUri, canHandleMessageType, parseMessageType } from '../utils/messageType'

@injectable()
export class MessageHandlerRegistry {
Expand Down Expand Up @@ -47,13 +48,24 @@ export class MessageHandlerRegistry {
* Returns array of protocol IDs that dispatcher is able to handle.
* Protocol ID format is PIURI specified at https://github.com/hyperledger/aries-rfcs/blob/main/concepts/0003-protocols/README.md#piuri.
*/
public get supportedProtocols() {
return Array.from(new Set(this.supportedMessageTypes.map((m) => m.protocolUri)))
public get supportedProtocolUris() {
const seenProtocolUris = new Set<string>()

const protocolUris: ParsedDidCommProtocolUri[] = this.supportedMessageTypes
.filter((m) => {
const has = seenProtocolUris.has(m.protocolUri)
seenProtocolUris.add(m.protocolUri)
return !has
})
// eslint-disable-next-line @typescript-eslint/no-unused-vars
.map(({ messageName, messageTypeUri, ...parsedProtocolUri }) => parsedProtocolUri)

return protocolUris
}

public filterSupportedProtocolsByMessageFamilies(messageFamilies: string[]) {
return this.supportedProtocols.filter((protocolId) =>
messageFamilies.find((messageFamily) => protocolId.startsWith(messageFamily))
public filterSupportedProtocolsByProtocolUris(parsedProtocolUris: ParsedDidCommProtocolUri[]) {
return this.supportedProtocolUris.filter((supportedProtocol) =>
parsedProtocolUris.some((p) => supportsIncomingDidCommProtocolUri(supportedProtocol, p))
)
}
}
28 changes: 14 additions & 14 deletions packages/core/src/agent/__tests__/MessageHandlerRegistry.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { MessageHandler } from '../MessageHandler'

import { parseMessageType } from '../../utils/messageType'
import { parseDidCommProtocolUri, parseMessageType } from '../../utils/messageType'
import { AgentMessage } from '../AgentMessage'
import { MessageHandlerRegistry } from '../MessageHandlerRegistry'

Expand Down Expand Up @@ -74,36 +74,36 @@ describe('MessageHandlerRegistry', () => {

describe('supportedProtocols', () => {
test('return all supported message protocols URIs', async () => {
const messageTypes = messageHandlerRegistry.supportedProtocols
const messageTypes = messageHandlerRegistry.supportedProtocolUris

expect(messageTypes).toEqual([
'https://didcomm.org/connections/1.0',
'https://didcomm.org/notification/1.0',
'https://didcomm.org/issue-credential/1.0',
'https://didcomm.org/fake-protocol/1.5',
parseDidCommProtocolUri('https://didcomm.org/connections/1.0'),
parseDidCommProtocolUri('https://didcomm.org/notification/1.0'),
parseDidCommProtocolUri('https://didcomm.org/issue-credential/1.0'),
parseDidCommProtocolUri('https://didcomm.org/fake-protocol/1.5'),
])
})
})

describe('filterSupportedProtocolsByMessageFamilies', () => {
describe('filterSupportedProtocolsByProtocolUris', () => {
it('should return empty array when input is empty array', async () => {
const supportedProtocols = messageHandlerRegistry.filterSupportedProtocolsByMessageFamilies([])
const supportedProtocols = messageHandlerRegistry.filterSupportedProtocolsByProtocolUris([])
expect(supportedProtocols).toEqual([])
})

it('should return empty array when input contains only unsupported protocol', async () => {
const supportedProtocols = messageHandlerRegistry.filterSupportedProtocolsByMessageFamilies([
'https://didcomm.org/unsupported-protocol/1.0',
const supportedProtocols = messageHandlerRegistry.filterSupportedProtocolsByProtocolUris([
parseDidCommProtocolUri('https://didcomm.org/unsupported-protocol/1.0'),
])
expect(supportedProtocols).toEqual([])
})

it('should return array with only supported protocol when input contains supported and unsupported protocol', async () => {
const supportedProtocols = messageHandlerRegistry.filterSupportedProtocolsByMessageFamilies([
'https://didcomm.org/connections',
'https://didcomm.org/didexchange',
const supportedProtocols = messageHandlerRegistry.filterSupportedProtocolsByProtocolUris([
parseDidCommProtocolUri('https://didcomm.org/connections/1.0'),
parseDidCommProtocolUri('https://didcomm.org/didexchange/1.0'),
])
expect(supportedProtocols).toEqual(['https://didcomm.org/connections/1.0'])
expect(supportedProtocols).toEqual([parseDidCommProtocolUri('https://didcomm.org/connections/1.0')])
})
})

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
/**
* Enum values should be sorted based on order of preference. Values will be
* included in this order when creating out of band invitations.
*/
export enum HandshakeProtocol {
Connections = 'https://didcomm.org/connections/1.0',
DidExchange = 'https://didcomm.org/didexchange/1.1',
DidExchange = 'https://didcomm.org/didexchange/1.x',
Connections = 'https://didcomm.org/connections/1.x',
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import type { ConnectionMetadata } from './ConnectionMetadataTypes'
import type { TagsBase } from '../../../storage/BaseRecord'
import type { HandshakeProtocol } from '../models'
import type { ConnectionType } from '../models/ConnectionType'

import { Transform } from 'class-transformer'

import { AriesFrameworkError } from '../../../error'
import { BaseRecord } from '../../../storage/BaseRecord'
import { uuid } from '../../../utils/uuid'
import { rfc0160StateFromDidExchangeState, DidExchangeRole, DidExchangeState } from '../models'
import { rfc0160StateFromDidExchangeState, DidExchangeRole, DidExchangeState, HandshakeProtocol } from '../models'

export interface ConnectionRecordProps {
id?: string
Expand Down Expand Up @@ -46,10 +47,7 @@ export type DefaultConnectionTags = {
previousTheirDids?: Array<string>
}

export class ConnectionRecord
extends BaseRecord<DefaultConnectionTags, CustomConnectionTags, ConnectionMetadata>
implements ConnectionRecordProps
{
export class ConnectionRecord extends BaseRecord<DefaultConnectionTags, CustomConnectionTags, ConnectionMetadata> {
public state!: DidExchangeState
public role!: DidExchangeRole

Expand All @@ -65,6 +63,18 @@ export class ConnectionRecord
public threadId?: string
public mediatorId?: string
public errorMessage?: string

// We used to store connection record using major.minor version, but we now
// only store the major version, storing .x for the minor version. We have this
// transformation so we don't have to migrate the data in the database.
@Transform(
({ value }) => {
if (!value || typeof value !== 'string' || value.endsWith('.x')) return value
return value.split('.').slice(0, -1).join('.') + '.x'
},

{ toClassOnly: true }
)
public protocol?: HandshakeProtocol
public outOfBandId?: string
public invitationDid?: string
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { DidExchangeRole, DidExchangeState } from '../../models'
import { JsonTransformer } from '../../../../utils'
import { DidExchangeRole, DidExchangeState, HandshakeProtocol } from '../../models'
import { ConnectionRecord } from '../ConnectionRecord'

describe('ConnectionRecord', () => {
Expand Down Expand Up @@ -30,4 +31,26 @@ describe('ConnectionRecord', () => {
})
})
})

it('should transform handshake protocol with minor version to .x', () => {
const connectionRecord = JsonTransformer.fromJSON(
{
protocol: 'https://didcomm.org/didexchange/1.0',
},
ConnectionRecord
)

expect(connectionRecord.protocol).toEqual(HandshakeProtocol.DidExchange)
})

it('should not transform handshake protocol when minor version is .x', () => {
const connectionRecord = JsonTransformer.fromJSON(
{
protocol: 'https://didcomm.org/didexchange/1.x',
},
ConnectionRecord
)

expect(connectionRecord.protocol).toEqual(HandshakeProtocol.DidExchange)
})
})
106 changes: 75 additions & 31 deletions packages/core/src/modules/oob/OutOfBandApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ import { AriesFrameworkError } from '../../error'
import { Logger } from '../../logger'
import { inject, injectable } from '../../plugins'
import { JsonEncoder, JsonTransformer } from '../../utils'
import { parseMessageType, supportsIncomingMessageType } from '../../utils/messageType'
import {
parseDidCommProtocolUri,
parseMessageType,
supportsIncomingDidCommProtocolUri,
supportsIncomingMessageType,
} from '../../utils/messageType'
import { parseInvitationShortUrl } from '../../utils/parseInvitation'
import { ConnectionsApi, DidExchangeState, HandshakeProtocol } from '../connections'
import { DidCommDocumentService } from '../didcomm'
Expand Down Expand Up @@ -166,16 +171,17 @@ export class OutOfBandApi {
throw new AriesFrameworkError("Attribute 'multiUseInvitation' can not be 'true' when 'messages' is defined.")
}

let handshakeProtocols
let handshakeProtocols: string[] | undefined
if (handshake) {
// Find supported handshake protocol preserving the order of handshake protocols defined
// by agent
// Assert ALL custom handshake protocols are supported
if (customHandshakeProtocols) {
this.assertHandshakeProtocols(customHandshakeProtocols)
handshakeProtocols = customHandshakeProtocols
} else {
handshakeProtocols = this.getSupportedHandshakeProtocols()
this.assertHandshakeProtocolsSupported(customHandshakeProtocols)
}

// Find supported handshake protocol preserving the order of handshake protocols defined by agent or in config
handshakeProtocols = this.getSupportedHandshakeProtocols(customHandshakeProtocols).map(
(p) => p.parsedProtocolUri.protocolUri
)
}

const routing = config.routing ?? (await this.routingService.getRouting(this.agentContext, {}))
Expand Down Expand Up @@ -365,11 +371,15 @@ export class OutOfBandApi {
* @returns out-of-band record and connection record if one has been created.
*/
public async receiveImplicitInvitation(config: ReceiveOutOfBandImplicitInvitationConfig) {
const handshakeProtocols = this.getSupportedHandshakeProtocols(
config.handshakeProtocols ?? [HandshakeProtocol.DidExchange]
).map((p) => p.parsedProtocolUri.protocolUri)

const invitation = new OutOfBandInvitation({
id: config.did,
label: config.label ?? '',
services: [config.did],
handshakeProtocols: config.handshakeProtocols ?? [HandshakeProtocol.DidExchange],
handshakeProtocols,
})

return this._receiveInvitation(invitation, { ...config, isImplicit: true })
Expand Down Expand Up @@ -580,13 +590,13 @@ export class OutOfBandApi {
this.logger.debug('Connection does not exist or reuse is disabled. Creating a new connection.')
// Find first supported handshake protocol preserving the order of handshake protocols
// defined by `handshake_protocols` attribute in the invitation message
const handshakeProtocol = this.getFirstSupportedProtocol(handshakeProtocols)
const firstSupportedProtocol = this.getFirstSupportedProtocol(handshakeProtocols)
connectionRecord = await this.connectionsApi.acceptOutOfBandInvitation(outOfBandRecord, {
label,
alias,
imageUrl,
autoAcceptConnection,
protocol: handshakeProtocol,
protocol: firstSupportedProtocol.handshakeProtocol,
routing,
ourDid,
})
Expand Down Expand Up @@ -699,47 +709,81 @@ export class OutOfBandApi {
return this.outOfBandService.deleteById(this.agentContext, outOfBandId)
}

private assertHandshakeProtocols(handshakeProtocols: HandshakeProtocol[]) {
private assertHandshakeProtocolsSupported(handshakeProtocols: HandshakeProtocol[]) {
if (!this.areHandshakeProtocolsSupported(handshakeProtocols)) {
const supportedProtocols = this.getSupportedHandshakeProtocols()
const supportedProtocols = this.getSupportedHandshakeProtocols().map((p) => p.handshakeProtocol)
throw new AriesFrameworkError(
`Handshake protocols [${handshakeProtocols}] are not supported. Supported protocols are [${supportedProtocols}]`
)
}
}

private areHandshakeProtocolsSupported(handshakeProtocols: HandshakeProtocol[]) {
const supportedProtocols = this.getSupportedHandshakeProtocols()
return handshakeProtocols.every((p) => supportedProtocols.includes(p))
const supportedProtocols = this.getSupportedHandshakeProtocols(handshakeProtocols)
return supportedProtocols.length === handshakeProtocols.length
}

private getSupportedHandshakeProtocols(): HandshakeProtocol[] {
// TODO: update to featureRegistry
const handshakeMessageFamilies = ['https://didcomm.org/didexchange', 'https://didcomm.org/connections']
const handshakeProtocols =
this.messageHandlerRegistry.filterSupportedProtocolsByMessageFamilies(handshakeMessageFamilies)
private getSupportedHandshakeProtocols(limitToHandshakeProtocols?: HandshakeProtocol[]) {
const allHandshakeProtocols = limitToHandshakeProtocols ?? Object.values(HandshakeProtocol)

// Replace .x in the handshake protocol with .0 to allow it to be parsed
const parsedHandshakeProtocolUris = allHandshakeProtocols.map((h) => ({
handshakeProtocol: h,
parsedProtocolUri: parseDidCommProtocolUri(h.replace('.x', '.0')),
}))

if (handshakeProtocols.length === 0) {
// Now find all handshake protocols that start with the protocol uri without minor version '<base-uri>/<protocol-name>/<major-version>.'
const supportedHandshakeProtocols = this.messageHandlerRegistry.filterSupportedProtocolsByProtocolUris(
parsedHandshakeProtocolUris.map((p) => p.parsedProtocolUri)
)

if (supportedHandshakeProtocols.length === 0) {
throw new AriesFrameworkError('There is no handshake protocol supported. Agent can not create a connection.')
}

// Order protocols according to `handshakeMessageFamilies` array
const orderedProtocols = handshakeMessageFamilies
.map((messageFamily) => handshakeProtocols.find((p) => p.startsWith(messageFamily)))
.filter((item): item is string => !!item)
// Order protocols according to `parsedHandshakeProtocolUris` array (order of preference)
const orderedProtocols = parsedHandshakeProtocolUris
.map((p) => {
const found = supportedHandshakeProtocols.find((s) =>
supportsIncomingDidCommProtocolUri(s, p.parsedProtocolUri)
)
// We need to override the parsedProtocolUri with the one from the supported protocols, as we used `.0` as the minor
// version before. But when we return it, we want to return the correct minor version that we actually support
return found ? { ...p, parsedProtocolUri: found } : null
})
.filter((p): p is NonNullable<typeof p> => p !== null)

return orderedProtocols as HandshakeProtocol[]
return orderedProtocols
}

private getFirstSupportedProtocol(handshakeProtocols: HandshakeProtocol[]) {
/**
* Get the first supported protocol based on the handshake protocols provided in the out of band
* invitation.
*
* Returns an enum value from {@link HandshakeProtocol} or throw an error if no protocol is supported.
* Minor versions are ignored when selecting a supported protocols, so if the `outOfBandInvitationSupportedProtocolsWithMinorVersion`
* value is `https://didcomm.org/didexchange/1.0` and the agent supports `https://didcomm.org/didexchange/1.1`
* this will be fine, and the returned value will be {@link HandshakeProtocol.DidExchange}.
*/
private getFirstSupportedProtocol(protocolUris: string[]) {
const supportedProtocols = this.getSupportedHandshakeProtocols()
const handshakeProtocol = handshakeProtocols.find((p) => supportedProtocols.includes(p))
if (!handshakeProtocol) {
const parsedProtocolUris = protocolUris.map(parseDidCommProtocolUri)

const firstSupportedProtocol = supportedProtocols.find((supportedProtocol) =>
parsedProtocolUris.find((parsedProtocol) =>
supportsIncomingDidCommProtocolUri(supportedProtocol.parsedProtocolUri, parsedProtocol)
)
)

if (!firstSupportedProtocol) {
throw new AriesFrameworkError(
`Handshake protocols [${handshakeProtocols}] are not supported. Supported protocols are [${supportedProtocols}]`
`Handshake protocols [${protocolUris}] are not supported. Supported protocols are [${supportedProtocols.map(
(p) => p.handshakeProtocol
)}]`
)
}
return handshakeProtocol

return firstSupportedProtocol
}

private async findExistingConnection(outOfBandInvitation: OutOfBandInvitation) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,30 @@ describe('out of band', () => {
expect(senderReceiverConnection).toBeConnectedWith(receiverSenderConnection)
})

test(`make a connection with self using https://didcomm.org/didexchange/1.1 protocol, but invitation using https://didcomm.org/didexchange/1.0`, async () => {
const outOfBandRecord = await faberAgent.oob.createInvitation()

const { outOfBandInvitation } = outOfBandRecord
outOfBandInvitation.handshakeProtocols = ['https://didcomm.org/didexchange/1.0']
const urlMessage = outOfBandInvitation.toUrl({ domain: 'http://example.com' })

// eslint-disable-next-line prefer-const
let { outOfBandRecord: receivedOutOfBandRecord, connectionRecord: receiverSenderConnection } =
await faberAgent.oob.receiveInvitationFromUrl(urlMessage)
expect(receivedOutOfBandRecord.state).toBe(OutOfBandState.PrepareResponse)

receiverSenderConnection = await faberAgent.connections.returnWhenIsConnected(receiverSenderConnection!.id)
expect(receiverSenderConnection.state).toBe(DidExchangeState.Completed)

let [senderReceiverConnection] = await faberAgent.connections.findAllByOutOfBandId(outOfBandRecord.id)
senderReceiverConnection = await faberAgent.connections.returnWhenIsConnected(senderReceiverConnection.id)
expect(senderReceiverConnection.state).toBe(DidExchangeState.Completed)
expect(senderReceiverConnection.protocol).toBe(HandshakeProtocol.DidExchange)

expect(receiverSenderConnection).toBeConnectedWith(senderReceiverConnection!)
expect(senderReceiverConnection).toBeConnectedWith(receiverSenderConnection)
})

test(`make a connection with self using ${HandshakeProtocol.Connections} protocol`, async () => {
const outOfBandRecord = await faberAgent.oob.createInvitation({
handshakeProtocols: [HandshakeProtocol.Connections],
Expand Down
Loading

0 comments on commit 40063e0

Please sign in to comment.