Skip to content

Commit

Permalink
Allow relaying messages to self (#2834)
Browse files Browse the repository at this point in the history
Allow sending messages to self

Fixes corner cases caused by compact encoding of node ids. Every message to be relayed now follows the same path and `MessageRelay` can relay to self.
  • Loading branch information
thomash-acinq authored Mar 4, 2024
1 parent c866be3 commit 1b3e4b0
Show file tree
Hide file tree
Showing 41 changed files with 343 additions and 358 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package fr.acinq.eclair.crypto

import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto}
import fr.acinq.eclair.EncodedNodeId
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.Attempt
Expand Down Expand Up @@ -341,14 +342,14 @@ object Sphinx extends Logging {
object RouteBlinding {

/**
* @param publicKey introduction node's public key (which cannot be blinded since the sender need to find a route to it).
* @param nodeId introduction node's id (which cannot be blinded since the sender need to find a route to it).
* @param blindedPublicKey blinded public key, which hides the real public key.
* @param blindingEphemeralKey blinding tweak that can be used by the receiving node to derive the private key that
* matches the blinded public key.
* @param encryptedPayload encrypted payload that can be decrypted with the introduction node's private key and the
* blinding ephemeral key.
*/
case class IntroductionNode(publicKey: PublicKey, blindedPublicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector)
case class IntroductionNode(nodeId: EncodedNodeId, blindedPublicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector)

/**
* @param blindedPublicKey blinded public key, which hides the real public key.
Expand All @@ -363,7 +364,7 @@ object Sphinx extends Logging {
* matches the blinded public key.
* @param blindedNodes blinded nodes (including the introduction node).
*/
case class BlindedRoute(introductionNodeId: PublicKey, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) {
case class BlindedRoute(introductionNodeId: EncodedNodeId, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) {
require(blindedNodes.nonEmpty, "blinded route must not be empty")
val introductionNode: IntroductionNode = IntroductionNode(introductionNodeId, blindedNodes.head.blindedPublicKey, blindingKey, blindedNodes.head.encryptedPayload)
val subsequentNodes: Seq[BlindedNode] = blindedNodes.tail
Expand Down Expand Up @@ -398,7 +399,7 @@ object Sphinx extends Logging {
e = e.multiply(PrivateKey(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes)))
(BlindedNode(blindedPublicKey, encryptedPayload ++ mac), blindingKey)
}.unzip
BlindedRouteDetails(BlindedRoute(publicKeys.head, blindingKeys.head, blindedHops), blindingKeys.last)
BlindedRouteDetails(BlindedRoute(EncodedNodeId(publicKeys.head), blindingKeys.head, blindedHops), blindingKeys.last)
}

/**
Expand Down
11 changes: 7 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ object MessageRelay {
case class Disconnected(messageId: ByteVector32) extends Failure {
override def toString: String = "Peer is not connected"
}
case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown outgoing channel: $outgoingChannelId"
case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown channel: $channelId"
}
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure {
override def toString: String = s"Message dropped: $reason"
Expand Down Expand Up @@ -99,6 +99,8 @@ private class MessageRelay(nodeParams: NodeParams,

def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = {
nextNode match {
case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf =>
withNextNodeId(msg, nodeParams.nodeId)
case Left(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
waitForNextNodeId(msg, outgoingChannelId)
Expand All @@ -110,14 +112,15 @@ private class MessageRelay(nodeParams: NodeParams,
}
}

private def waitForNextNodeId(msg: OnionMessage, outgoingChannelId: ShortChannelId): Behavior[Command] =
private def waitForNextNodeId(msg: OnionMessage, channelId: ShortChannelId): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedOptionalNodeId(None) =>
replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId))
replyTo_opt.foreach(_ ! UnknownChannel(messageId, channelId))
Behaviors.stopped
case WrappedOptionalNodeId(Some(nextNodeId)) =>
withNextNodeId(msg, nextNodeId)
}
}

private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
if (nextNodeId == nodeParams.nodeId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => {
case Some(hop: NodeHop) if channelNodeIds.nonEmpty => Seq(hop.nextNodeId)
case Some(hop: NodeHop) => Seq(hop.nodeId, hop.nextNodeId)
case Some(hop: BlindedHop) if channelNodeIds.nonEmpty => hop.route.blindedNodeIds.tail
case Some(hop: BlindedHop) => hop.route.introductionNodeId +: hop.route.blindedNodeIds.tail
case Some(hop: BlindedHop) => hop.nodeId +: hop.route.blindedNodeIds.tail
case None => Nil
}
RouteNodeIdsJson(route.amount, channelNodeIds ++ finalNodeIds)
Expand Down Expand Up @@ -468,14 +468,8 @@ object InvoiceSerializer extends MinimalSerializer({
UnknownFeatureSerializer
)),
JField("blindedPaths", JArray(p.blindedPaths.map(path => {
val introductionNode = path.route match {
case OfferTypes.BlindedPath(route) => route.introductionNodeId.toString
case OfferTypes.CompactBlindedPath(shortIdDir, _, _) => s"${if (shortIdDir.isNode1) '0' else '1'}x${shortIdDir.scid.toString}"
}
val blindedNodes = path.route match {
case OfferTypes.BlindedPath(route) => route.blindedNodes
case OfferTypes.CompactBlindedPath(_, _, nodes) => nodes
}
val introductionNode = path.route.introductionNodeId.toString
val blindedNodes = path.route.blindedNodes
JObject(List(
JField("introductionNodeId", JString(introductionNode)),
JField("blindedNodeIds", JArray(blindedNodes.map(n => JString(n.blindedPublicKey.toString)).toList))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import fr.acinq.eclair.wire.protocol._
import scodec.bits.ByteVector
import scodec.{Attempt, DecodeResult}

import scala.annotation.tailrec
import scala.concurrent.duration.FiniteDuration

object OnionMessages {
Expand All @@ -44,23 +43,29 @@ object OnionMessages {
timeout: FiniteDuration,
maxAttempts: Int)

case class IntermediateNode(nodeId: PublicKey, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) {
def toTlvStream(nextNodeId: PublicKey, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] =
case class IntermediateNode(publicKey: PublicKey, encodedNodeId: EncodedNodeId, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) {
def toTlvStream(nextNodeId: EncodedNodeId, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] =
TlvStream(Set[Option[RouteBlindingEncryptedDataTlv]](
padding.map(Padding),
outgoingChannel_opt.map(OutgoingChannelId).orElse(Some(OutgoingNodeId(nextNodeId))),
nextBlinding_opt.map(NextBlinding)
).flatten, customTlvs)
}

object IntermediateNode {
def apply(publicKey: PublicKey): IntermediateNode = IntermediateNode(publicKey, EncodedNodeId(publicKey))
}

// @formatter:off
sealed trait Destination {
def nodeId: PublicKey
def introductionNodeId: EncodedNodeId
}
case class BlindedPath(route: Sphinx.RouteBlinding.BlindedRoute) extends Destination {
override def nodeId: PublicKey = route.introductionNodeId
override def introductionNodeId: EncodedNodeId = route.introductionNodeId
}
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination {
override def introductionNodeId: EncodedNodeId = EncodedNodeId(nodeId)
}
case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination
// @formatter:on

// @formatter:off
Expand All @@ -75,11 +80,11 @@ object OnionMessages {
}
// @formatter:on

private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: PublicKey, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = {
private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: EncodedNodeId, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = {
if (intermediateNodes.isEmpty) {
Nil
} else {
val intermediatePayloads = intermediateNodes.dropRight(1).zip(intermediateNodes.tail).map { case (hop, nextNode) => hop.toTlvStream(nextNode.nodeId) }
val intermediatePayloads = intermediateNodes.dropRight(1).zip(intermediateNodes.tail).map { case (hop, nextNode) => hop.toTlvStream(nextNode.encodedNodeId) }
val lastPayload = intermediateNodes.last.toTlvStream(lastNodeId, lastBlinding_opt)
(intermediatePayloads :+ lastPayload).map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes)
}
Expand All @@ -88,33 +93,22 @@ object OnionMessages {
def buildRoute(blindingSecret: PrivateKey,
intermediateNodes: Seq[IntermediateNode],
recipient: Recipient): Sphinx.RouteBlinding.BlindedRoute = {
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, recipient.nodeId)
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, EncodedNodeId(recipient.nodeId))
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten
val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs, recipient.customTlvs)).require.bytes
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.publicKey) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route
}

private[message] def buildRouteFrom(originKey: PrivateKey,
blindingSecret: PrivateKey,
private[message] def buildRouteFrom(blindingSecret: PrivateKey,
intermediateNodes: Seq[IntermediateNode],
destination: Destination): Option[Sphinx.RouteBlinding.BlindedRoute] = {
destination: Destination): Sphinx.RouteBlinding.BlindedRoute = {
destination match {
case recipient: Recipient => Some(buildRoute(blindingSecret, intermediateNodes, recipient))
case BlindedPath(route) if route.introductionNodeId == originKey.publicKey =>
RouteBlindingEncryptedDataCodecs.decode(originKey, route.blindingKey, route.blindedNodes.head.encryptedPayload) match {
case Left(_) => None
case Right(decoded) =>
decoded.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId] match {
case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.Plain(nextNodeId))) =>
Some(Sphinx.RouteBlinding.BlindedRoute(nextNodeId, decoded.nextBlinding, route.blindedNodes.tail))
case _ => None // TODO: allow compact node id and OutgoingChannelId
}
}
case BlindedPath(route) if intermediateNodes.isEmpty => Some(route)
case recipient: Recipient => buildRoute(blindingSecret, intermediateNodes, recipient)
case BlindedPath(route) if intermediateNodes.isEmpty => route
case BlindedPath(route) =>
val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, route.introductionNodeId, Some(route.blindingKey))
val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId), intermediatePayloads).route
Some(Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes))
val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.publicKey), intermediatePayloads).route
Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes)
}
}

Expand All @@ -134,32 +128,28 @@ object OnionMessages {
* @param content List of TLVs to send to the recipient of the message
* @return The node id to send the onion to and the onion containing the message
*/
def buildMessage(nodeKey: PrivateKey,
sessionKey: PrivateKey,
def buildMessage(sessionKey: PrivateKey,
blindingSecret: PrivateKey,
intermediateNodes: Seq[IntermediateNode],
destination: Destination,
content: TlvStream[OnionMessagePayloadTlv]): Either[BuildMessageError, (PublicKey, OnionMessage)] = {
buildRouteFrom(nodeKey, blindingSecret, intermediateNodes, destination) match {
case None => Left(InvalidDestination(destination))
case Some(route) =>
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes
val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload
val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum
val packetSize = if (payloadSize <= 1300) {
1300
} else if (payloadSize <= 32768) {
32768
} else if (payloadSize > 65432) {
// A payload of size 65432 corresponds to a total lightning message size of 65535.
return Left(MessageTooLarge(payloadSize))
} else {
payloadSize.toInt
}
// Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`).
val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get
Right((route.introductionNodeId, OnionMessage(route.blindingKey, packet)))
content: TlvStream[OnionMessagePayloadTlv]): Either[BuildMessageError, OnionMessage] = {
val route = buildRouteFrom(blindingSecret, intermediateNodes, destination)
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes
val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload
val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum
val packetSize = if (payloadSize <= 1300) {
1300
} else if (payloadSize <= 32768) {
32768
} else if (payloadSize > 65432) {
// A payload of size 65432 corresponds to a total lightning message size of 65535.
return Left(MessageTooLarge(payloadSize))
} else {
payloadSize.toInt
}
// Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`).
val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get
Right(OnionMessage(route.blindingKey, packet))
}

// @formatter:off
Expand Down Expand Up @@ -199,7 +189,6 @@ object OnionMessages {
}
}

@tailrec
def process(privateKey: PrivateKey, msg: OnionMessage): Action = {
val blindedPrivateKey = Sphinx.RouteBlinding.derivePrivateKey(privateKey, msg.blindingKey)
decryptOnion(blindedPrivateKey, msg.onionRoutingPacket) match {
Expand All @@ -210,11 +199,7 @@ object OnionMessages {
decryptEncryptedData(privateKey, msg.blindingKey, encryptedData) match {
case Left(f) => DropMessage(f)
case Right(DecodedEncryptedData(blindedPayload, nextBlinding)) => nextPacket_opt match {
case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) match {
case SendMessage(Right(EncodedNodeId.Plain(publicKey)), nextMsg) if publicKey == privateKey.publicKey => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
case action => action
}
case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket)
case None => validateFinalPayload(payload, blindedPayload)
}
}
Expand Down
Loading

0 comments on commit 1b3e4b0

Please sign in to comment.