Skip to content

Commit

Permalink
Identify as wallet node in blinded routes (#652)
Browse files Browse the repository at this point in the history
Adds a `EncodedNodeId.WithPublicKey.Wallet` to signal to the trampoline node that we are a wallet node.
This is not part of the spec as only our trampoline node will see it.
It removes the need for `ShortChannelId.peerId`.

Co-authored-by: Bastien Teinturier <31281497+t-bast@users.noreply.github.com>
  • Loading branch information
thomash-acinq and t-bast authored May 31, 2024
1 parent db29bfe commit a0d1999
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 51 deletions.
23 changes: 16 additions & 7 deletions src/commonMain/kotlin/fr/acinq/lightning/EncodedNodeId.kt
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
package fr.acinq.lightning

import fr.acinq.bitcoin.PublicKey
import fr.acinq.bitcoin.io.Input
import fr.acinq.bitcoin.io.Output
import fr.acinq.lightning.wire.LightningCodecs

sealed class EncodedNodeId {
/** Nodes are usually identified by their public key. */
data class Plain(val publicKey: PublicKey) : EncodedNodeId() {
override fun toString(): String = publicKey.toString()
sealed class WithPublicKey : EncodedNodeId() {
abstract val publicKey: PublicKey

/** Standard case where a node is identified by its public key. */
data class Plain(override val publicKey: PublicKey) : WithPublicKey() {
override fun toString(): String = publicKey.toString()
}

/**
* Wallet nodes are not part of the public graph, and may not have channels yet.
* Wallet providers are usually able to contact such nodes using push notifications or similar mechanisms.
*/
data class Wallet(override val publicKey: PublicKey) : WithPublicKey() {
override fun toString(): String = publicKey.toString()
}
}

/** For compactness, nodes may be identified by the shortChannelId of one of their public channels. */
Expand All @@ -17,6 +26,6 @@ sealed class EncodedNodeId {
}

companion object {
operator fun invoke(publicKey: PublicKey): EncodedNodeId = Plain(publicKey)
operator fun invoke(publicKey: PublicKey): WithPublicKey.Plain = WithPublicKey.Plain(publicKey)
}
}
18 changes: 10 additions & 8 deletions src/commonMain/kotlin/fr/acinq/lightning/message/OnionMessages.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import fr.acinq.lightning.utils.toByteVector
import fr.acinq.lightning.wire.*

object OnionMessages {
data class IntermediateNode(val nodeId: PublicKey, val outgoingChannelId: ShortChannelId? = null, val padding: ByteVector? = null, val customTlvs: Set<GenericTlv> = setOf()) {
data class IntermediateNode(val nodeId: EncodedNodeId.WithPublicKey, val outgoingChannelId: ShortChannelId? = null, val padding: ByteVector? = null, val customTlvs: Set<GenericTlv> = setOf()) {
fun toTlvStream(nextNodeId: EncodedNodeId, nextBlinding: PublicKey? = null): TlvStream<RouteBlindingEncryptedDataTlv> {
val tlvs = setOfNotNull(
outgoingChannelId?.let { RouteBlindingEncryptedDataTlv.OutgoingChannelId(it) } ?: RouteBlindingEncryptedDataTlv.OutgoingNodeId(nextNodeId),
Expand All @@ -24,13 +24,13 @@ object OnionMessages {

sealed class Destination {
data class BlindedPath(val route: RouteBlinding.BlindedRoute) : Destination()
data class Recipient(val nodeId: PublicKey, val pathId: ByteVector?, val padding: ByteVector? = null, val customTlvs: Set<GenericTlv> = setOf()) : Destination()
data class Recipient(val nodeId: EncodedNodeId.WithPublicKey, val pathId: ByteVector?, val padding: ByteVector? = null, val customTlvs: Set<GenericTlv> = setOf()) : Destination()

companion object {
operator fun invoke(contactInfo: OfferTypes.ContactInfo): Destination =
when (contactInfo) {
is OfferTypes.ContactInfo.BlindedPath -> BlindedPath(contactInfo.route)
is OfferTypes.ContactInfo.RecipientNodeId -> Recipient(contactInfo.nodeId, null)
is OfferTypes.ContactInfo.RecipientNodeId -> Recipient(EncodedNodeId.WithPublicKey.Plain(contactInfo.nodeId), null)
}
}
}
Expand All @@ -43,7 +43,9 @@ object OnionMessages {
return if (intermediateNodes.isEmpty()) {
listOf()
} else {
val intermediatePayloads = intermediateNodes.dropLast(1).zip(intermediateNodes.drop(1)).map { (current, next) -> current.toTlvStream(EncodedNodeId(next.nodeId)) }
val intermediatePayloads = intermediateNodes.dropLast(1).zip(intermediateNodes.drop(1)).map { (current, next) ->
current.toTlvStream(next.nodeId)
}
// The last intermediate node may contain a blinding override when the recipient is hidden behind a blinded path.
val lastPayload = intermediateNodes.last().toTlvStream(lastNodeId, lastBlinding)
(intermediatePayloads + lastPayload).map { RouteBlindingEncryptedData(it).write().byteVector() }
Expand All @@ -57,15 +59,15 @@ object OnionMessages {
): RouteBlinding.BlindedRoute {
return when (destination) {
is Destination.Recipient -> {
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, EncodedNodeId(destination.nodeId))
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, destination.nodeId)
val tlvs = setOfNotNull(
destination.padding?.let { RouteBlindingEncryptedDataTlv.Padding(it) },
destination.pathId?.let { RouteBlindingEncryptedDataTlv.PathId(it) }
)
val lastPayload = RouteBlindingEncryptedData(TlvStream(tlvs, destination.customTlvs)).write().toByteVector()
RouteBlinding.create(
blindingSecret,
intermediateNodes.map { it.nodeId } + destination.nodeId,
intermediateNodes.map { it.nodeId.publicKey } + destination.nodeId.publicKey,
intermediatePayloads + lastPayload
).route
}
Expand All @@ -80,7 +82,7 @@ object OnionMessages {
)
val routePrefix = RouteBlinding.create(
blindingSecret,
intermediateNodes.map { it.nodeId },
intermediateNodes.map { it.nodeId.publicKey },
intermediatePayloads
).route
RouteBlinding.BlindedRoute(
Expand Down Expand Up @@ -167,7 +169,7 @@ object OnionMessages {
null
}
is Either.Right -> when {
!decrypted.value.isLastPacket && relayInfo.value.nextNodeId == EncodedNodeId(privateKey.publicKey()) -> {
!decrypted.value.isLastPacket && relayInfo.value.nextNodeId == EncodedNodeId.WithPublicKey.Wallet(privateKey.publicKey()) -> {
// We may add ourselves to the route several times at the end to hide the real length of the route.
val nextMessage = OnionMessage(relayInfo.value.nextBlindingOverride ?: nextBlinding, decrypted.value.nextPacket)
decryptMessage(privateKey, nextMessage, logger)
Expand Down
13 changes: 7 additions & 6 deletions src/commonMain/kotlin/fr/acinq/lightning/payment/OfferManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class OfferManager(val nodeParams: NodeParams, val walletParams: WalletParams, v
val replyPathId = randomBytes32()
pendingInvoiceRequests[replyPathId] = PendingInvoiceRequest(payOffer, request)
// We add dummy hops to the reply path: this way the receiver only learns that we're at most 3 hops away from our peer.
val replyPathHops = listOf(remoteNodeId, nodeParams.nodeId, nodeParams.nodeId).map { IntermediateNode(it) }
val lastHop = Destination.Recipient(nodeParams.nodeId, replyPathId)
val replyPathHops = listOf(IntermediateNode(EncodedNodeId.WithPublicKey.Plain(remoteNodeId)), IntermediateNode(EncodedNodeId.WithPublicKey.Wallet(nodeParams.nodeId)), IntermediateNode(EncodedNodeId.WithPublicKey.Wallet(nodeParams.nodeId)))
val lastHop = Destination.Recipient(EncodedNodeId.WithPublicKey.Wallet(nodeParams.nodeId), replyPathId)
val replyPath = OnionMessages.buildRoute(randomKey(), replyPathHops, lastHop)
val messageContent = TlvStream(OnionMessagePayloadTlv.ReplyPath(replyPath), OnionMessagePayloadTlv.InvoiceRequest(request.records))
val invoiceRequests = payOffer.offer.contactInfos.mapNotNull { contactInfo ->
Expand Down Expand Up @@ -154,7 +154,7 @@ class OfferManager(val nodeParams: NodeParams, val walletParams: WalletParams, v
)
val remoteNodePayload = RouteBlindingEncryptedData(
TlvStream(
RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId.peerId(nodeParams.nodeId)),
RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(nodeParams.nodeId)),
RouteBlindingEncryptedDataTlv.PaymentRelay(paymentInfo.cltvExpiryDelta, paymentInfo.feeProportionalMillionths, paymentInfo.feeBase),
RouteBlindingEncryptedDataTlv.PaymentConstraints((paymentInfo.cltvExpiryDelta + nodeParams.maxFinalCltvExpiryDelta).toCltvExpiry(currentBlockHeight.toLong()), paymentInfo.minHtlc)
)
Expand Down Expand Up @@ -186,11 +186,12 @@ class OfferManager(val nodeParams: NodeParams, val walletParams: WalletParams, v
private fun intermediateNodes(destination: Destination): List<IntermediateNode> {
val needIntermediateHop = when (destination) {
is Destination.BlindedPath -> when (val introduction = destination.route.introductionNodeId) {
is EncodedNodeId.Plain -> introduction.publicKey != remoteNodeId
is EncodedNodeId.WithPublicKey.Plain -> introduction.publicKey != remoteNodeId
is EncodedNodeId.WithPublicKey.Wallet -> true
is EncodedNodeId.ShortChannelIdDir -> true // we don't have access to the graph data and rely on our peer to resolve the scid
}
is Destination.Recipient -> destination.nodeId != remoteNodeId
is Destination.Recipient -> destination.nodeId.publicKey != remoteNodeId
}
return if (needIntermediateHop) listOf(IntermediateNode(remoteNodeId)) else listOf()
return if (needIntermediateHop) listOf(IntermediateNode(EncodedNodeId.WithPublicKey.Plain(remoteNodeId))) else listOf()
}
}
34 changes: 22 additions & 12 deletions src/commonMain/kotlin/fr/acinq/lightning/wire/LightningCodecs.kt
Original file line number Diff line number Diff line change
Expand Up @@ -225,25 +225,35 @@ object LightningCodecs {
}

fun encodedNodeId(input: Input): EncodedNodeId {
val firstByte = byte(input)
if (firstByte == 0 || firstByte == 1) {
val isNode1 = firstByte == 0
val scid = ShortChannelId(int64(input))
return EncodedNodeId.ShortChannelIdDir(isNode1, scid)
} else if (firstByte == 2 || firstByte == 3) {
val publicKey = PublicKey(ByteArray(1) { firstByte.toByte() } + bytes(input, 32))
return EncodedNodeId.Plain(publicKey)
} else {
throw IllegalArgumentException("unexpected first byte: $firstByte")
return when (val firstByte = byte(input)) {
0, 1 -> {
val isNode1 = firstByte == 0
val scid = ShortChannelId(int64(input))
EncodedNodeId.ShortChannelIdDir(isNode1, scid)
}
2, 3 -> {
val publicKey = PublicKey(ByteArray(1) { firstByte.toByte() } + bytes(input, 32))
EncodedNodeId.WithPublicKey.Plain(publicKey)
}
4, 5 -> {
val publicKey = PublicKey(ByteArray(1) { (firstByte - 2).toByte() } + bytes(input, 32))
EncodedNodeId.WithPublicKey.Wallet(publicKey)
}
else -> throw IllegalArgumentException("unexpected first byte: $firstByte")
}
}

fun writeEncodedNodeId(input: EncodedNodeId, out: Output): Unit = when (input) {
is EncodedNodeId.Plain -> writeBytes(input.publicKey.value, out)
is EncodedNodeId.WithPublicKey.Plain -> writeBytes(input.publicKey.value, out)
is EncodedNodeId.ShortChannelIdDir -> {
writeByte(if (input.isNode1) 0 else 1, out)
writeInt64(input.scid.toLong(), out)
}
is EncodedNodeId.WithPublicKey.Wallet -> {
val firstByte = input.publicKey.value[0]
writeByte(firstByte + 2, out)
writeBytes(input.publicKey.value.drop(1), out)
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ object OfferTypes {
customTlvs: Set<GenericTlv> = setOf()
): Offer {
if (description == null) require(amount == null) { "an offer description must be provided if the amount isn't null" }
val path = OnionMessages.buildRoute(blindingSecret, listOf(OnionMessages.IntermediateNode(trampolineNode.id, ShortChannelId.peerId(nodeParams.nodeId))), OnionMessages.Destination.Recipient(nodeParams.nodeId, null))
val path = OnionMessages.buildRoute(blindingSecret, listOf(OnionMessages.IntermediateNode(EncodedNodeId.WithPublicKey.Plain(trampolineNode.id))), OnionMessages.Destination.Recipient(EncodedNodeId.WithPublicKey.Wallet(nodeParams.nodeId), null))
val tlvs: Set<OfferTlv> = setOfNotNull(
if (nodeParams.chainHash != Block.LivenetGenesisBlock.hash) OfferChains(listOf(nodeParams.chainHash)) else null,
amount?.let { OfferAmount(it) },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class OnionMessagesTestsCommon {
val blindingSecret = randomKey()
val destination = randomKey()
val pathId = randomBytes32()
val message = buildMessage(sessionKey, blindingSecret, listOf(), Recipient(destination.publicKey(), pathId), TlvStream.empty())
val message = buildMessage(sessionKey, blindingSecret, listOf(), Recipient(EncodedNodeId(destination.publicKey()), pathId), TlvStream.empty())
assertIs<Either.Right<OnionMessage>>(message)

val decrypted = decryptMessage(destination, message.value, logger)
Expand Down Expand Up @@ -117,9 +117,9 @@ class OnionMessagesTestsCommon {
val onionForAlice = OnionMessage(blindingSecret.publicKey(), packet)

// Building the onion with functions from `OnionMessages`
val replyPath = buildRoute(blindingOverride, listOf(IntermediateNode(carol.publicKey(), padding = ByteVector.fromHex("0000000000000000000000000000000000000000000000000000000000000000000000"))), Recipient(dave.publicKey(), ByteVector.fromHex("01234567")))
val replyPath = buildRoute(blindingOverride, listOf(IntermediateNode(EncodedNodeId(carol.publicKey()), padding = ByteVector.fromHex("0000000000000000000000000000000000000000000000000000000000000000000000"))), Recipient(EncodedNodeId(dave.publicKey()), ByteVector.fromHex("01234567")))
assertEquals(routeFromCarol, replyPath)
val message = buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(alice.publicKey()), IntermediateNode(bob.publicKey())), BlindedPath(replyPath), TlvStream.empty())
val message = buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(EncodedNodeId(alice.publicKey())), IntermediateNode(EncodedNodeId(bob.publicKey()))), BlindedPath(replyPath), TlvStream.empty())
assertEquals(Either.Right(onionForAlice), message)

// Checking that the onion is relayed properly
Expand Down Expand Up @@ -211,7 +211,7 @@ class OnionMessagesTestsCommon {
val blindingSecret = randomKey()
val blindingOverride = randomKey()
val destination = randomKey()
val replyPath = buildRoute(blindingOverride, listOf(IntermediateNode(destination.publicKey())), Recipient(destination.publicKey(), pathId = ByteVector.fromHex("01234567")))
val replyPath = buildRoute(blindingOverride, listOf(IntermediateNode(EncodedNodeId(destination.publicKey()))), Recipient(EncodedNodeId(destination.publicKey()), pathId = ByteVector.fromHex("01234567")))
assertEquals(blindingOverride.publicKey(), replyPath.blindingKey)
assertEquals(EncodedNodeId(destination.publicKey()), replyPath.introductionNodeId)
val message = buildMessage(sessionKey, blindingSecret, listOf(), BlindedPath(replyPath), TlvStream.empty()).right!!
Expand All @@ -232,7 +232,7 @@ class OnionMessagesTestsCommon {
val sessionKey = randomKey()
val blindingSecret = randomKey()
val pathId = randomBytes(65201).toByteVector()
val messageForAlice = buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(alice.publicKey()), IntermediateNode(bob.publicKey())), Recipient(carol.publicKey(), pathId), TlvStream.empty()).right!!
val messageForAlice = buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(EncodedNodeId(alice.publicKey())), IntermediateNode(EncodedNodeId(bob.publicKey()))), Recipient(EncodedNodeId(carol.publicKey()), pathId), TlvStream.empty()).right!!
// This message should use the maximum size allowed for lightning messages, without overflowing it.
// Note that we leave 2 bytes for the message length, resulting in a total packet of 65535 bytes.
assertEquals(65533, messageForAlice.write().size)
Expand All @@ -257,7 +257,7 @@ class OnionMessagesTestsCommon {
val pathId = randomBytes(65202).toByteVector()
assertEquals(
Either.Left(OnionMessages.MessageTooLarge(65433)),
buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(alice.publicKey()), IntermediateNode(bob.publicKey())), Recipient(carol.publicKey(), pathId), TlvStream.empty())
buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(EncodedNodeId(alice.publicKey())), IntermediateNode(EncodedNodeId(bob.publicKey()))), Recipient(EncodedNodeId(carol.publicKey()), pathId), TlvStream.empty())
)
}

Expand All @@ -271,7 +271,7 @@ class OnionMessagesTestsCommon {
val sessionKey = randomKey()
val blindingSecret = randomKey()
val pathId = randomBytes(64).toByteVector()
val messageForAlice = buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(alice.publicKey(), alice2bob), IntermediateNode(bob.publicKey(), bob2carol)), Recipient(carol.publicKey(), pathId), TlvStream.empty()).right!!
val messageForAlice = buildMessage(sessionKey, blindingSecret, listOf(IntermediateNode(EncodedNodeId(alice.publicKey()), alice2bob), IntermediateNode(EncodedNodeId(bob.publicKey()), bob2carol)), Recipient(EncodedNodeId(carol.publicKey()), pathId), TlvStream.empty()).right!!

// The onion is relayed properly:
val (outgoingChannelId1, onionForBob) = relayMessage(alice, messageForAlice)
Expand Down
Loading

0 comments on commit a0d1999

Please sign in to comment.