From 1b3e4b0daee337ef3ab918d24bec43d4365625ee Mon Sep 17 00:00:00 2001 From: Thomas HUET <81159533+thomash-acinq@users.noreply.github.com> Date: Mon, 4 Mar 2024 15:44:55 +0100 Subject: [PATCH] Allow relaying messages to self (#2834) Allow sending messages to self Fixes corner cases caused by compact encoding of node ids. Every message to be relayed now follows the same path and `MessageRelay` can relay to self. --- .../scala/fr/acinq/eclair/crypto/Sphinx.scala | 9 +- .../fr/acinq/eclair/io/MessageRelay.scala | 11 ++- .../acinq/eclair/json/JsonSerializers.scala | 12 +-- .../acinq/eclair/message/OnionMessages.scala | 95 ++++++++----------- .../fr/acinq/eclair/message/Postman.scala | 47 ++++----- .../acinq/eclair/payment/Bolt12Invoice.scala | 7 +- .../payment/receive/MultiPartHandler.scala | 15 +-- .../eclair/payment/relay/NodeRelay.scala | 8 +- .../send/CompactBlindedPathsResolver.scala | 34 +++---- .../eclair/payment/send/OfferPayment.scala | 23 ++--- .../payment/send/PaymentInitiator.scala | 6 +- .../acinq/eclair/payment/send/Recipient.scala | 19 ++-- .../scala/fr/acinq/eclair/router/Router.scala | 4 +- .../eclair/wire/protocol/MessageOnion.scala | 10 +- .../eclair/wire/protocol/OfferCodecs.scala | 32 +++---- .../eclair/wire/protocol/OfferTypes.scala | 15 ++- .../eclair/wire/protocol/PaymentOnion.scala | 14 +-- .../fr/acinq/eclair/crypto/SphinxSpec.scala | 4 +- .../fr/acinq/eclair/db/PaymentsDbSpec.scala | 2 +- .../integration/MessageIntegrationSpec.scala | 5 +- .../integration/PaymentIntegrationSpec.scala | 21 ++-- .../basic/payment/OfferPaymentSpec.scala | 10 +- .../fr/acinq/eclair/io/MessageRelaySpec.scala | 51 ++++++---- .../acinq/eclair/io/PeerConnectionSpec.scala | 10 +- .../scala/fr/acinq/eclair/io/PeerSpec.scala | 4 +- .../eclair/message/OnionMessagesSpec.scala | 38 ++++---- .../fr/acinq/eclair/message/PostmanSpec.scala | 46 +++++---- .../eclair/payment/Bolt12InvoiceSpec.scala | 4 +- .../eclair/payment/MultiPartHandlerSpec.scala | 16 ++-- .../eclair/payment/PaymentInitiatorSpec.scala | 15 +-- .../eclair/payment/PaymentPacketSpec.scala | 17 ++-- .../payment/offer/OfferManagerSpec.scala | 2 +- .../payment/relay/NodeRelayerSpec.scala | 12 +-- .../payment/send/OfferPaymentSpec.scala | 28 +++--- .../acinq/eclair/router/BaseRouterSpec.scala | 11 ++- .../router/BlindedRouteCreationSpec.scala | 6 +- .../protocol/MessageOnionCodecsSpec.scala | 10 +- .../eclair/wire/protocol/OfferTypesSpec.scala | 15 ++- .../wire/protocol/PaymentOnionSpec.scala | 9 +- .../api/serde/FormParamExtractors.scala | 2 +- .../fr/acinq/eclair/api/ApiServiceSpec.scala | 2 +- 41 files changed, 343 insertions(+), 358 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index e4c88a7747..ce426bdd62 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} +import fr.acinq.eclair.EncodedNodeId import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging import scodec.Attempt @@ -341,14 +342,14 @@ object Sphinx extends Logging { object RouteBlinding { /** - * @param publicKey introduction node's public key (which cannot be blinded since the sender need to find a route to it). + * @param nodeId introduction node's id (which cannot be blinded since the sender need to find a route to it). * @param blindedPublicKey blinded public key, which hides the real public key. * @param blindingEphemeralKey blinding tweak that can be used by the receiving node to derive the private key that * matches the blinded public key. * @param encryptedPayload encrypted payload that can be decrypted with the introduction node's private key and the * blinding ephemeral key. */ - case class IntroductionNode(publicKey: PublicKey, blindedPublicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector) + case class IntroductionNode(nodeId: EncodedNodeId, blindedPublicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector) /** * @param blindedPublicKey blinded public key, which hides the real public key. @@ -363,7 +364,7 @@ object Sphinx extends Logging { * matches the blinded public key. * @param blindedNodes blinded nodes (including the introduction node). */ - case class BlindedRoute(introductionNodeId: PublicKey, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) { + case class BlindedRoute(introductionNodeId: EncodedNodeId, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) { require(blindedNodes.nonEmpty, "blinded route must not be empty") val introductionNode: IntroductionNode = IntroductionNode(introductionNodeId, blindedNodes.head.blindedPublicKey, blindingKey, blindedNodes.head.encryptedPayload) val subsequentNodes: Seq[BlindedNode] = blindedNodes.tail @@ -398,7 +399,7 @@ object Sphinx extends Logging { e = e.multiply(PrivateKey(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes))) (BlindedNode(blindedPublicKey, encryptedPayload ++ mac), blindingKey) }.unzip - BlindedRouteDetails(BlindedRoute(publicKeys.head, blindingKeys.head, blindedHops), blindingKeys.last) + BlindedRouteDetails(BlindedRoute(EncodedNodeId(publicKeys.head), blindingKeys.head, blindedHops), blindingKeys.last) } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 52fe8f2821..3e7383f21f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -59,8 +59,8 @@ object MessageRelay { case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" } - case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure { - override def toString: String = s"Unknown outgoing channel: $outgoingChannelId" + case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { + override def toString: String = s"Unknown channel: $channelId" } case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" @@ -99,6 +99,8 @@ private class MessageRelay(nodeParams: NodeParams, def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = { nextNode match { + case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf => + withNextNodeId(msg, nodeParams.nodeId) case Left(outgoingChannelId) => register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) waitForNextNodeId(msg, outgoingChannelId) @@ -110,14 +112,15 @@ private class MessageRelay(nodeParams: NodeParams, } } - private def waitForNextNodeId(msg: OnionMessage, outgoingChannelId: ShortChannelId): Behavior[Command] = + private def waitForNextNodeId(msg: OnionMessage, channelId: ShortChannelId): Behavior[Command] = { Behaviors.receiveMessagePartial { case WrappedOptionalNodeId(None) => - replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId)) + replyTo_opt.foreach(_ ! UnknownChannel(messageId, channelId)) Behaviors.stopped case WrappedOptionalNodeId(Some(nextNodeId)) => withNextNodeId(msg, nextNodeId) } + } private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { if (nextNodeId == nodeParams.nodeId) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala index 9a0d6e4a2b..8e77a504cd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/json/JsonSerializers.scala @@ -357,7 +357,7 @@ object RouteNodeIdsSerializer extends ConvertClassSerializer[Route](route => { case Some(hop: NodeHop) if channelNodeIds.nonEmpty => Seq(hop.nextNodeId) case Some(hop: NodeHop) => Seq(hop.nodeId, hop.nextNodeId) case Some(hop: BlindedHop) if channelNodeIds.nonEmpty => hop.route.blindedNodeIds.tail - case Some(hop: BlindedHop) => hop.route.introductionNodeId +: hop.route.blindedNodeIds.tail + case Some(hop: BlindedHop) => hop.nodeId +: hop.route.blindedNodeIds.tail case None => Nil } RouteNodeIdsJson(route.amount, channelNodeIds ++ finalNodeIds) @@ -468,14 +468,8 @@ object InvoiceSerializer extends MinimalSerializer({ UnknownFeatureSerializer )), JField("blindedPaths", JArray(p.blindedPaths.map(path => { - val introductionNode = path.route match { - case OfferTypes.BlindedPath(route) => route.introductionNodeId.toString - case OfferTypes.CompactBlindedPath(shortIdDir, _, _) => s"${if (shortIdDir.isNode1) '0' else '1'}x${shortIdDir.scid.toString}" - } - val blindedNodes = path.route match { - case OfferTypes.BlindedPath(route) => route.blindedNodes - case OfferTypes.CompactBlindedPath(_, _, nodes) => nodes - } + val introductionNode = path.route.introductionNodeId.toString + val blindedNodes = path.route.blindedNodes JObject(List( JField("introductionNodeId", JString(introductionNode)), JField("blindedNodeIds", JArray(blindedNodes.map(n => JString(n.blindedPublicKey.toString)).toList)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala index 05f599c13b..8364a76352 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala @@ -27,7 +27,6 @@ import fr.acinq.eclair.wire.protocol._ import scodec.bits.ByteVector import scodec.{Attempt, DecodeResult} -import scala.annotation.tailrec import scala.concurrent.duration.FiniteDuration object OnionMessages { @@ -44,8 +43,8 @@ object OnionMessages { timeout: FiniteDuration, maxAttempts: Int) - case class IntermediateNode(nodeId: PublicKey, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) { - def toTlvStream(nextNodeId: PublicKey, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] = + case class IntermediateNode(publicKey: PublicKey, encodedNodeId: EncodedNodeId, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) { + def toTlvStream(nextNodeId: EncodedNodeId, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] = TlvStream(Set[Option[RouteBlindingEncryptedDataTlv]]( padding.map(Padding), outgoingChannel_opt.map(OutgoingChannelId).orElse(Some(OutgoingNodeId(nextNodeId))), @@ -53,14 +52,20 @@ object OnionMessages { ).flatten, customTlvs) } + object IntermediateNode { + def apply(publicKey: PublicKey): IntermediateNode = IntermediateNode(publicKey, EncodedNodeId(publicKey)) + } + // @formatter:off sealed trait Destination { - def nodeId: PublicKey + def introductionNodeId: EncodedNodeId } case class BlindedPath(route: Sphinx.RouteBlinding.BlindedRoute) extends Destination { - override def nodeId: PublicKey = route.introductionNodeId + override def introductionNodeId: EncodedNodeId = route.introductionNodeId + } + case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination { + override def introductionNodeId: EncodedNodeId = EncodedNodeId(nodeId) } - case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination // @formatter:on // @formatter:off @@ -75,11 +80,11 @@ object OnionMessages { } // @formatter:on - private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: PublicKey, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = { + private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: EncodedNodeId, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = { if (intermediateNodes.isEmpty) { Nil } else { - val intermediatePayloads = intermediateNodes.dropRight(1).zip(intermediateNodes.tail).map { case (hop, nextNode) => hop.toTlvStream(nextNode.nodeId) } + val intermediatePayloads = intermediateNodes.dropRight(1).zip(intermediateNodes.tail).map { case (hop, nextNode) => hop.toTlvStream(nextNode.encodedNodeId) } val lastPayload = intermediateNodes.last.toTlvStream(lastNodeId, lastBlinding_opt) (intermediatePayloads :+ lastPayload).map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes) } @@ -88,33 +93,22 @@ object OnionMessages { def buildRoute(blindingSecret: PrivateKey, intermediateNodes: Seq[IntermediateNode], recipient: Recipient): Sphinx.RouteBlinding.BlindedRoute = { - val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, recipient.nodeId) + val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, EncodedNodeId(recipient.nodeId)) val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs, recipient.customTlvs)).require.bytes - Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route + Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.publicKey) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route } - private[message] def buildRouteFrom(originKey: PrivateKey, - blindingSecret: PrivateKey, + private[message] def buildRouteFrom(blindingSecret: PrivateKey, intermediateNodes: Seq[IntermediateNode], - destination: Destination): Option[Sphinx.RouteBlinding.BlindedRoute] = { + destination: Destination): Sphinx.RouteBlinding.BlindedRoute = { destination match { - case recipient: Recipient => Some(buildRoute(blindingSecret, intermediateNodes, recipient)) - case BlindedPath(route) if route.introductionNodeId == originKey.publicKey => - RouteBlindingEncryptedDataCodecs.decode(originKey, route.blindingKey, route.blindedNodes.head.encryptedPayload) match { - case Left(_) => None - case Right(decoded) => - decoded.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId] match { - case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.Plain(nextNodeId))) => - Some(Sphinx.RouteBlinding.BlindedRoute(nextNodeId, decoded.nextBlinding, route.blindedNodes.tail)) - case _ => None // TODO: allow compact node id and OutgoingChannelId - } - } - case BlindedPath(route) if intermediateNodes.isEmpty => Some(route) + case recipient: Recipient => buildRoute(blindingSecret, intermediateNodes, recipient) + case BlindedPath(route) if intermediateNodes.isEmpty => route case BlindedPath(route) => val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, route.introductionNodeId, Some(route.blindingKey)) - val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId), intermediatePayloads).route - Some(Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes)) + val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.publicKey), intermediatePayloads).route + Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes) } } @@ -134,32 +128,28 @@ object OnionMessages { * @param content List of TLVs to send to the recipient of the message * @return The node id to send the onion to and the onion containing the message */ - def buildMessage(nodeKey: PrivateKey, - sessionKey: PrivateKey, + def buildMessage(sessionKey: PrivateKey, blindingSecret: PrivateKey, intermediateNodes: Seq[IntermediateNode], destination: Destination, - content: TlvStream[OnionMessagePayloadTlv]): Either[BuildMessageError, (PublicKey, OnionMessage)] = { - buildRouteFrom(nodeKey, blindingSecret, intermediateNodes, destination) match { - case None => Left(InvalidDestination(destination)) - case Some(route) => - val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes - val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload - val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum - val packetSize = if (payloadSize <= 1300) { - 1300 - } else if (payloadSize <= 32768) { - 32768 - } else if (payloadSize > 65432) { - // A payload of size 65432 corresponds to a total lightning message size of 65535. - return Left(MessageTooLarge(payloadSize)) - } else { - payloadSize.toInt - } - // Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`). - val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get - Right((route.introductionNodeId, OnionMessage(route.blindingKey, packet))) + content: TlvStream[OnionMessagePayloadTlv]): Either[BuildMessageError, OnionMessage] = { + val route = buildRouteFrom(blindingSecret, intermediateNodes, destination) + val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes + val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload + val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum + val packetSize = if (payloadSize <= 1300) { + 1300 + } else if (payloadSize <= 32768) { + 32768 + } else if (payloadSize > 65432) { + // A payload of size 65432 corresponds to a total lightning message size of 65535. + return Left(MessageTooLarge(payloadSize)) + } else { + payloadSize.toInt } + // Since we are setting the packet size based on the payload, the onion creation should never fail (hence the `.get`). + val Sphinx.PacketAndSecrets(packet, _) = Sphinx.create(sessionKey, packetSize, route.blindedNodes.map(_.blindedPublicKey), payloads, None).get + Right(OnionMessage(route.blindingKey, packet)) } // @formatter:off @@ -199,7 +189,6 @@ object OnionMessages { } } - @tailrec def process(privateKey: PrivateKey, msg: OnionMessage): Action = { val blindedPrivateKey = Sphinx.RouteBlinding.derivePrivateKey(privateKey, msg.blindingKey) decryptOnion(blindedPrivateKey, msg.onionRoutingPacket) match { @@ -210,11 +199,7 @@ object OnionMessages { decryptEncryptedData(privateKey, msg.blindingKey, encryptedData) match { case Left(f) => DropMessage(f) case Right(DecodedEncryptedData(blindedPayload, nextBlinding)) => nextPacket_opt match { - case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) match { - case SendMessage(Right(EncodedNodeId.Plain(publicKey)), nextMsg) if publicKey == privateKey.publicKey => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay - case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay - case action => action - } + case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) case None => validateFinalPayload(payload, blindedPayload) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala index e25f73fa79..68087932ef 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala @@ -29,7 +29,7 @@ import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteNotFound, MessageRouteResponse} import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoiceRequestPayload} -import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, ContactInfo} +import fr.acinq.eclair.wire.protocol.OfferTypes.ContactInfo import fr.acinq.eclair.wire.protocol.{OfferTypes, OnionMessagePayloadTlv, TlvStream} import fr.acinq.eclair.{EncodedNodeId, NodeParams, randomBytes32, randomKey} @@ -149,42 +149,43 @@ private class SendingMessage(nodeParams: NodeParams, Behaviors.receiveMessagePartial { case SendMessage => contactInfo match { - case compact: OfferTypes.CompactBlindedPath => - router ! Router.GetNodeId(context.messageAdapter(WrappedNodeIdResponse), compact.introductionNode.scid, compact.introductionNode.isNode1) - waitForNodeId(compact) - case OfferTypes.BlindedPath(route) => sendToDestination(OnionMessages.BlindedPath(route)) - case OfferTypes.RecipientNodeId(nodeId) => sendToDestination(OnionMessages.Recipient(nodeId, None)) + case OfferTypes.BlindedPath(route@BlindedRoute(EncodedNodeId.ShortChannelIdDir(isNode1, scid), _, _)) => + router ! Router.GetNodeId(context.messageAdapter(WrappedNodeIdResponse), scid, isNode1) + waitForNodeId(route) + case OfferTypes.BlindedPath(route@BlindedRoute(EncodedNodeId.Plain(publicKey), _, _)) => sendToDestination(OnionMessages.BlindedPath(route), publicKey) + case OfferTypes.RecipientNodeId(nodeId) => sendToDestination(OnionMessages.Recipient(nodeId, None), nodeId) } } } - private def waitForNodeId(compactBlindedPath: CompactBlindedPath): Behavior[Command] = { + private def waitForNodeId(compactBlindedPath: BlindedRoute): Behavior[Command] = { Behaviors.receiveMessagePartial { case WrappedNodeIdResponse(None) => - replyTo ! Postman.MessageFailed(s"Could not resolve introduction node for compact blinded path (scid=${compactBlindedPath.introductionNode.scid.toCoordinatesString})") + replyTo ! Postman.MessageFailed(s"Could not resolve introduction node for compact blinded path: ${compactBlindedPath.introductionNode.nodeId}") Behaviors.stopped case WrappedNodeIdResponse(Some(nodeId)) => - sendToDestination(OnionMessages.BlindedPath(BlindedRoute(nodeId, compactBlindedPath.blindingKey, compactBlindedPath.blindedNodes))) + sendToDestination(OnionMessages.BlindedPath(compactBlindedPath), nodeId) } } - private def sendToDestination(destination: Destination): Behavior[Command] = { + private def sendToDestination(destination: Destination, plainNodeId: PublicKey): Behavior[Command] = { routingStrategy match { - case RoutingStrategy.UseRoute(intermediateNodes) => sendToRoute(intermediateNodes, destination) - case RoutingStrategy.FindRoute if destination.nodeId == nodeParams.nodeId => - context.self ! WrappedMessageRouteResponse(MessageRoute(Nil, destination.nodeId)) - waitForRouteFromRouter(destination) + case RoutingStrategy.UseRoute(intermediateNodes) => + sendToRoute(intermediateNodes, destination, plainNodeId) + case RoutingStrategy.FindRoute if plainNodeId == nodeParams.nodeId => + context.self ! WrappedMessageRouteResponse(MessageRoute(Nil, plainNodeId)) + waitForRouteFromRouter(destination, plainNodeId) case RoutingStrategy.FindRoute => - router ! Router.MessageRouteRequest(context.messageAdapter(WrappedMessageRouteResponse), nodeParams.nodeId, destination.nodeId, Set.empty) - waitForRouteFromRouter(destination) + router ! Router.MessageRouteRequest(context.messageAdapter(WrappedMessageRouteResponse), nodeParams.nodeId, plainNodeId, Set.empty) + waitForRouteFromRouter(destination, plainNodeId) } } - private def waitForRouteFromRouter(destination: Destination): Behavior[Command] = { + private def waitForRouteFromRouter(destination: Destination, plainNodeId: PublicKey): Behavior[Command] = { Behaviors.receiveMessagePartial { case WrappedMessageRouteResponse(MessageRoute(intermediateNodes, targetNodeId)) => context.log.debug("Found route: {}", (intermediateNodes :+ targetNodeId).mkString(" -> ")) - sendToRoute(intermediateNodes, destination) + sendToRoute(intermediateNodes, destination, plainNodeId) case WrappedMessageRouteResponse(MessageRouteNotFound(targetNodeId)) => context.log.debug("No route found to {}", targetNodeId) replyTo ! Postman.MessageFailed("No route found") @@ -192,19 +193,18 @@ private class SendingMessage(nodeParams: NodeParams, } } - private def sendToRoute(intermediateNodes: Seq[PublicKey], destination: Destination): Behavior[Command] = { + private def sendToRoute(intermediateNodes: Seq[PublicKey], destination: Destination, plainNodeId: PublicKey): Behavior[Command] = { val messageId = randomBytes32() val replyRoute = if (expectsReply) { val numHopsToAdd = 0.max(nodeParams.onionMessageConfig.minIntermediateHops - intermediateNodes.length - 1) - val intermediateHops = (Seq(destination.nodeId) ++ intermediateNodes.reverse ++ Seq.fill(numHopsToAdd)(nodeParams.nodeId)).map(OnionMessages.IntermediateNode(_)) + val intermediateHops = OnionMessages.IntermediateNode(plainNodeId, destination.introductionNodeId) +: (intermediateNodes.reverse ++ Seq.fill(numHopsToAdd)(nodeParams.nodeId)).map(OnionMessages.IntermediateNode(_)) val lastHop = OnionMessages.Recipient(nodeParams.nodeId, Some(messageId)) Some(OnionMessages.buildRoute(randomKey(), intermediateHops, lastHop)) } else { None } OnionMessages.buildMessage( - nodeParams.privateKey, randomKey(), randomKey(), intermediateNodes.map(OnionMessages.IntermediateNode(_)), @@ -213,9 +213,10 @@ private class SendingMessage(nodeParams: NodeParams, case Left(failure) => replyTo ! Postman.MessageFailed(failure.toString) Behaviors.stopped - case Right((nextNodeId, message)) => + case Right(message) => + val nextNodeId = EncodedNodeId(intermediateNodes.headOption.getOrElse(plainNodeId)) val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") - relay ! MessageRelay.RelayMessage(messageId, nodeParams.nodeId, Right(EncodedNodeId(nextNodeId)), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus))) + relay ! MessageRelay.RelayMessage(messageId, nodeParams.nodeId, Right(nextNodeId), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus))) waitForSent() } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala index e65811560c..bcb0e34c68 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Bolt12Invoice.scala @@ -19,7 +19,6 @@ package fr.acinq.eclair.payment import fr.acinq.bitcoin.Bech32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{BlockHash, ByteVector32, ByteVector64, Crypto} -import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv} @@ -53,7 +52,7 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { // We add invoice features that are implicitly required for Bolt 12 (the spec doesn't allow explicitly setting them). f.add(Features.VariableLengthOnion, FeatureSupport.Mandatory).add(Features.RouteBlinding, FeatureSupport.Mandatory) } - val blindedPaths: Seq[PaymentBlindedContactInfo] = records.get[InvoicePaths].get.paths.zip(records.get[InvoiceBlindedPay].get.paymentInfo).map { case (route, info) => PaymentBlindedContactInfo(route, info) } + val blindedPaths: Seq[PaymentBlindedRoute] = records.get[InvoicePaths].get.paths.zip(records.get[InvoiceBlindedPay].get.paymentInfo).map { case (route, info) => PaymentBlindedRoute(route, info) } val fallbacks: Option[Seq[FallbackAddress]] = records.get[InvoiceFallbacks].map(_.addresses) val signature: ByteVector64 = records.get[Signature].get.signature @@ -87,8 +86,6 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice { } -case class PaymentBlindedContactInfo(route: BlindedContactInfo, paymentInfo: PaymentInfo) - case class PaymentBlindedRoute(route: BlindedRoute, paymentInfo: PaymentInfo) object Bolt12Invoice { @@ -110,7 +107,7 @@ object Bolt12Invoice { nodeKey: PrivateKey, invoiceExpiry: FiniteDuration, features: Features[Bolt12Feature], - paths: Seq[PaymentBlindedContactInfo], + paths: Seq[PaymentBlindedRoute], additionalTlvs: Set[InvoiceTlv] = Set.empty, customTlvs: Set[GenericTlv] = Set.empty): Bolt12Invoice = { require(request.amount.nonEmpty || request.offer.amount.nonEmpty) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 9a0f730f66..769a00767a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -26,9 +26,10 @@ import akka.pattern.ask import akka.util.Timeout import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} -import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir +import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, RES_SUCCESS} +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.db._ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} @@ -370,22 +371,22 @@ object MultiPartHandler { createBlindedRouteFromHops(dummyHops, r.pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) } val contactInfo = route.shortChannelIdDir_opt match { - case Some(shortChannelIdDir) => OfferTypes.CompactBlindedPath(shortChannelIdDir, blindedRoute.route.blindingKey, blindedRoute.route.blindedNodes) - case None => OfferTypes.BlindedPath(blindedRoute.route) + case Some(shortChannelIdDir) => BlindedRoute(shortChannelIdDir, blindedRoute.route.blindingKey, blindedRoute.route.blindedNodes) + case None => blindedRoute.route } val paymentInfo = aggregatePaymentInfo(r.amount, dummyHops, nodeParams.channelConf.minFinalExpiryDelta) - Future.successful(PaymentBlindedContactInfo(contactInfo, paymentInfo)) + Future.successful(PaymentBlindedRoute(contactInfo, paymentInfo)) } else { implicit val timeout: Timeout = 10.seconds r.router.ask(Router.FinalizeRoute(Router.PredefinedNodeRoute(r.amount, route.nodes))).mapTo[Router.RouteResponse].map(routeResponse => { val clearRoute = routeResponse.routes.head val blindedRoute = createBlindedRouteFromHops(clearRoute.hops ++ dummyHops, r.pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) val contactInfo = route.shortChannelIdDir_opt match { - case Some(shortChannelIdDir) => OfferTypes.CompactBlindedPath(shortChannelIdDir, blindedRoute.route.blindingKey, blindedRoute.route.blindedNodes) - case None => OfferTypes.BlindedPath(blindedRoute.route) + case Some(shortChannelIdDir) => BlindedRoute(shortChannelIdDir, blindedRoute.route.blindingKey, blindedRoute.route.blindedNodes) + case None => blindedRoute.route } val paymentInfo = aggregatePaymentInfo(r.amount, clearRoute.hops ++ dummyHops, nodeParams.channelConf.minFinalExpiryDelta) - PaymentBlindedContactInfo(contactInfo, paymentInfo) + PaymentBlindedRoute(contactInfo, paymentInfo) }) } })).map(paths => { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 7db8183b8f..d81bf5dcaf 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart -import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.Resolve +import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.{Resolve, ResolvedPath} import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode @@ -63,7 +63,7 @@ object NodeRelay { private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command private[relay] case class WrappedPeerReadyResult(result: AsyncPaymentTriggerer.Result) extends Command - private case class WrappedResolvedPaths(resolved: Seq[PaymentBlindedRoute]) extends Command + private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command // @formatter:on trait OutgoingPaymentFactory { @@ -340,7 +340,7 @@ class NodeRelay private(nodeParams: NodeParams, relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = true) } case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths => - context.spawnAnonymous(CompactBlindedPathsResolver(router)) ! Resolve(context.messageAdapter[Seq[PaymentBlindedRoute]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths) + context.spawnAnonymous(CompactBlindedPathsResolver(router)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths) waitForResolvedPaths(upstream, payloadOut, paymentCfg, routeParams) } } @@ -378,7 +378,7 @@ class NodeRelay private(nodeParams: NodeParams, case WrappedResolvedPaths(resolved) => val features = Features(payloadOut.invoiceFeatures).invoiceFeatures() // We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient. - val blindedNodeId = resolved.head.route.blindedNodeIds.last + val blindedNodeId = resolved.head.blindedPath.route.blindedNodeIds.last val recipient = BlindedRecipient.fromPaths(blindedNodeId, features, payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty) context.log.debug("sending the payment to blinded recipient, useMultiPart={}", features.hasFeature(Features.BasicMultiPartPayment)) relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, features.hasFeature(Features.BasicMultiPartPayment)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/CompactBlindedPathsResolver.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/CompactBlindedPathsResolver.scala index efbfb88b22..bb4b95876e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/CompactBlindedPathsResolver.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/CompactBlindedPathsResolver.scala @@ -4,19 +4,21 @@ import akka.actor.typed.Behavior import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.EncodedNodeId import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute +import fr.acinq.eclair.payment.PaymentBlindedRoute import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver._ -import fr.acinq.eclair.payment.{PaymentBlindedContactInfo, PaymentBlindedRoute} import fr.acinq.eclair.router.Router -import fr.acinq.eclair.wire.protocol.OfferTypes.{BlindedPath, CompactBlindedPath, PaymentInfo} import scala.annotation.tailrec object CompactBlindedPathsResolver { // @formatter:off sealed trait Command - case class Resolve(replyTo: typed.ActorRef[Seq[PaymentBlindedRoute]], blindedPaths: Seq[PaymentBlindedContactInfo]) extends Command + case class Resolve(replyTo: typed.ActorRef[Seq[ResolvedPath]], blindedPaths: Seq[PaymentBlindedRoute]) extends Command private case class WrappedNodeId(nodeId_opt: Option[PublicKey]) extends Command + + case class ResolvedPath(blindedPath: PaymentBlindedRoute, introductionNodeId: PublicKey) // @formatter:on def apply(router: ActorRef): Behavior[Command] = { @@ -28,33 +30,31 @@ object CompactBlindedPathsResolver { } } -private class CompactBlindedPathsResolver(replyTo: typed.ActorRef[Seq[PaymentBlindedRoute]], +private class CompactBlindedPathsResolver(replyTo: typed.ActorRef[Seq[ResolvedPath]], router: ActorRef, context: ActorContext[Command]) { @tailrec - private def resolveCompactBlindedPaths(toResolve: Seq[PaymentBlindedContactInfo], - resolved: Seq[PaymentBlindedRoute]): Behavior[Command] = { + private def resolveCompactBlindedPaths(toResolve: Seq[PaymentBlindedRoute], + resolved: Seq[ResolvedPath]): Behavior[Command] = { toResolve.headOption match { - case Some(PaymentBlindedContactInfo(BlindedPath(route), paymentInfo)) => - resolveCompactBlindedPaths(toResolve.tail, resolved :+ PaymentBlindedRoute(route, paymentInfo)) - case Some(PaymentBlindedContactInfo(route: CompactBlindedPath, paymentInfo)) => - router ! Router.GetNodeId(context.messageAdapter(WrappedNodeId), route.introductionNode.scid, route.introductionNode.isNode1) - waitForNodeId(route, paymentInfo, toResolve.tail, resolved) + case Some(paymentRoute@PaymentBlindedRoute(BlindedRoute(EncodedNodeId.Plain(publicKey), _, _), _)) => + resolveCompactBlindedPaths(toResolve.tail, resolved :+ ResolvedPath(paymentRoute, publicKey)) + case Some(paymentRoute@PaymentBlindedRoute(BlindedRoute(EncodedNodeId.ShortChannelIdDir(isNode1, scid), _, _), _)) => + router ! Router.GetNodeId(context.messageAdapter(WrappedNodeId), scid, isNode1) + waitForNodeId(paymentRoute, toResolve.tail, resolved) case None => replyTo ! resolved Behaviors.stopped } } - private def waitForNodeId(compactRoute: CompactBlindedPath, - paymentInfo: PaymentInfo, - toResolve: Seq[PaymentBlindedContactInfo], - resolved: Seq[PaymentBlindedRoute]): Behavior[Command] = + private def waitForNodeId(paymentRoute: PaymentBlindedRoute, + toResolve: Seq[PaymentBlindedRoute], + resolved: Seq[ResolvedPath]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedNodeId(None) => resolveCompactBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) => - val resolvedPaymentBlindedRoute = PaymentBlindedRoute(BlindedRoute(nodeId, compactRoute.blindingKey, compactRoute.blindedNodes), paymentInfo) - resolveCompactBlindedPaths(toResolve, resolved :+ resolvedPaymentBlindedRoute) + resolveCompactBlindedPaths(toResolve, resolved :+ ResolvedPath(paymentRoute, nodeId)) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala index c76dc530e7..6e181e96d5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala @@ -24,17 +24,14 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.message.Postman.{OnionMessageResponse, SendMessage} import fr.acinq.eclair.message.{OnionMessages, Postman} -import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.Resolve +import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.{Resolve, ResolvedPath} import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendTrampolinePayment} -import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedContactInfo, PaymentBlindedRoute} -import fr.acinq.eclair.router.Router +import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedRoute} import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoicePayload} import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.{OnionMessagePayloadTlv, TlvStream} -import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, NodeParams, RealShortChannelId, TimestampSecond, randomKey} - -import scala.annotation.tailrec +import fr.acinq.eclair.{CltvExpiryDelta, EncodedNodeId, Features, InvoiceFeature, MilliSatoshi, NodeParams, RealShortChannelId, TimestampSecond, randomKey} object OfferPayment { sealed trait Failure @@ -69,7 +66,7 @@ object OfferPayment { case class WrappedMessageResponse(response: OnionMessageResponse) extends Command - private case class WrappedResolvedPaths(resolved: Seq[PaymentBlindedRoute]) extends Command + private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command case class SendPaymentConfig(externalId_opt: Option[String], connectDirectly: Boolean, @@ -140,7 +137,7 @@ private class OfferPayment(replyTo: ActorRef, paymentInitiator ! SendTrampolinePayment(replyTo, payload.invoice.amount, payload.invoice, trampoline.nodeId, trampoline.attempts, sendPaymentConfig.routeParams) Behaviors.stopped case None => - context.spawnAnonymous(CompactBlindedPathsResolver(router)) ! Resolve(context.messageAdapter[Seq[PaymentBlindedRoute]](WrappedResolvedPaths), payload.invoice.blindedPaths) + context.spawnAnonymous(CompactBlindedPathsResolver(router)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payload.invoice.blindedPaths) waitForResolvedPaths(payload.invoice) } case WrappedMessageResponse(Postman.Response(payload)) => @@ -163,13 +160,13 @@ private class OfferPayment(replyTo: ActorRef, */ private def waitForResolvedPaths(invoice: Bolt12Invoice): Behavior[Command] = Behaviors.receiveMessagePartial { - case WrappedResolvedPaths(resolved) if resolved.isEmpty => + case WrappedResolvedPaths(resolved) if resolved.isEmpty => // We couldn't identify any of the blinded paths' introduction nodes because the scids are unknown. - val scids = invoice.blindedPaths.collect { case PaymentBlindedContactInfo(CompactBlindedPath(scdidDir, _, _), _) => scdidDir.scid } + val scids = invoice.blindedPaths.collect { case PaymentBlindedRoute(BlindedRoute(EncodedNodeId.ShortChannelIdDir(_, scid), _, _), _) => scid } replyTo ! UnknownShortChannelIds(scids) - Behaviors.stopped - case WrappedResolvedPaths(resolved) => + Behaviors.stopped + case WrappedResolvedPaths(resolved) => paymentInitiator ! SendPaymentToNode(replyTo, invoice.amount, invoice, resolved, maxAttempts = sendPaymentConfig.maxAttempts, externalId = sendPaymentConfig.externalId_opt, routeParams = sendPaymentConfig.routeParams, payerKey_opt = Some(payerKey), blockUntilComplete = sendPaymentConfig.blocking) - Behaviors.stopped + Behaviors.stopped } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala index 5e57cef105..0d9ef88e4b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentInitiator.scala @@ -24,6 +24,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PaymentType import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment._ +import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.ResolvedPath import fr.acinq.eclair.payment.send.PaymentError._ import fr.acinq.eclair.router.RouteNotFound import fr.acinq.eclair.router.Router._ @@ -31,7 +32,6 @@ import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, NodeParams, randomBytes32} import java.util.UUID -import scala.util.{Failure, Success, Try} /** * Created by PM on 29/08/2016. @@ -299,7 +299,7 @@ object PaymentInitiator { case class SendPaymentToNode(replyTo: ActorRef, recipientAmount: MilliSatoshi, invoice: Invoice, - resolvedPaths: Seq[PaymentBlindedRoute], + resolvedPaths: Seq[ResolvedPath], maxAttempts: Int, externalId: Option[String] = None, routeParams: RouteParams, @@ -363,7 +363,7 @@ object PaymentInitiator { */ case class SendPaymentToRoute(recipientAmount: MilliSatoshi, invoice: Invoice, - resolvedPaths: Seq[PaymentBlindedRoute], + resolvedPaths: Seq[ResolvedPath], route: PredefinedRoute, externalId: Option[String], parentId: Option[UUID], diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala index e597f34fc5..ffe0f126f0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/Recipient.scala @@ -21,7 +21,8 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment.OutgoingPaymentPacket._ -import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, Invoice, OutgoingPaymentPacket, PaymentBlindedRoute} +import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.ResolvedPath +import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, Invoice, OutgoingPaymentPacket} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload, OutgoingBlindedPerHopPayload} import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionRoutingPacket} @@ -126,13 +127,13 @@ case class BlindedRecipient(nodeId: PublicKey, require(blindedHops.nonEmpty, "blinded routes must be provided") override val extraEdges = blindedHops.map { h => - ExtraEdge(h.route.introductionNodeId, nodeId, h.dummyId, h.paymentInfo.feeBase, h.paymentInfo.feeProportionalMillionths, h.paymentInfo.cltvExpiryDelta, h.paymentInfo.minHtlc, Some(h.paymentInfo.maxHtlc)) + ExtraEdge(h.nodeId, nodeId, h.dummyId, h.paymentInfo.feeBase, h.paymentInfo.feeProportionalMillionths, h.paymentInfo.cltvExpiryDelta, h.paymentInfo.minHtlc, Some(h.paymentInfo.maxHtlc)) } private def validateRoute(route: Route): Either[OutgoingPaymentError, BlindedHop] = { route.finalHop_opt match { case Some(blindedHop: BlindedHop) => Right(blindedHop) - case _ => Left(MissingBlindedHop(blindedHops.map(_.route.introductionNodeId).toSet)) + case _ => Left(MissingBlindedHop(blindedHops.map(_.nodeId).toSet)) } } @@ -140,9 +141,9 @@ case class BlindedRecipient(nodeId: PublicKey, val blinding = blindedHop.route.introductionNode.blindingEphemeralKey val payloads = if (blindedHop.route.subsequentNodes.isEmpty) { // The recipient is also the introduction node. - Seq(NodePayload(blindedHop.route.introductionNodeId, OutgoingBlindedPerHopPayload.createFinalIntroductionPayload(amount, totalAmount, expiry, blinding, blindedHop.route.introductionNode.encryptedPayload, customTlvs))) + Seq(NodePayload(blindedHop.nodeId, OutgoingBlindedPerHopPayload.createFinalIntroductionPayload(amount, totalAmount, expiry, blinding, blindedHop.route.introductionNode.encryptedPayload, customTlvs))) } else { - val introductionPayload = NodePayload(blindedHop.route.introductionNodeId, OutgoingBlindedPerHopPayload.createIntroductionPayload(blindedHop.route.introductionNode.encryptedPayload, blinding)) + val introductionPayload = NodePayload(blindedHop.nodeId, OutgoingBlindedPerHopPayload.createIntroductionPayload(blindedHop.route.introductionNode.encryptedPayload, blinding)) val intermediatePayloads = blindedHop.route.subsequentNodes.dropRight(1).map(n => NodePayload(n.blindedPublicKey, OutgoingBlindedPerHopPayload.createIntermediatePayload(n.encryptedPayload))) val finalPayload = NodePayload(blindedHop.route.blindedNodes.last.blindedPublicKey, OutgoingBlindedPerHopPayload.createFinalPayload(amount, totalAmount, expiry, blindedHop.route.blindedNodes.last.encryptedPayload, customTlvs)) introductionPayload +: intermediatePayloads :+ finalPayload @@ -170,16 +171,16 @@ object BlindedRecipient { * @param invoice Bolt invoice. Paths from the invoice must be passed as `paths` with compact paths expanded to include the node id. * @param paths Payment paths to use to reach the recipient. */ - def apply(invoice: Bolt12Invoice, paths: Seq[PaymentBlindedRoute], totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Set[GenericTlv]): BlindedRecipient = + def apply(invoice: Bolt12Invoice, paths: Seq[ResolvedPath], totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Set[GenericTlv]): BlindedRecipient = BlindedRecipient.fromPaths(invoice.nodeId, invoice.features, totalAmount, expiry, paths, customTlvs) - def fromPaths(nodeId: PublicKey, features: Features[InvoiceFeature], totalAmount: MilliSatoshi, expiry: CltvExpiry, paths: Seq[PaymentBlindedRoute], customTlvs: Set[GenericTlv]): BlindedRecipient = { + def fromPaths(nodeId: PublicKey, features: Features[InvoiceFeature], totalAmount: MilliSatoshi, expiry: CltvExpiry, paths: Seq[ResolvedPath], customTlvs: Set[GenericTlv]): BlindedRecipient = { val blindedHops = paths.map( - path => { + resolved => { // We don't know the scids of channels inside the blinded route, but it's useful to have an ID to refer to a // given edge in the graph, so we create a dummy one for the duration of the payment attempt. val dummyId = ShortChannelId.generateLocalAlias() - BlindedHop(dummyId, path.route, path.paymentInfo) + BlindedHop(resolved.introductionNodeId, dummyId, resolved.blindedPath.route, resolved.blindedPath.paymentInfo) }) BlindedRecipient(nodeId, features, totalAmount, expiry, blindedHops, customTlvs) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 9077c3e39e..31fbd2df21 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -523,14 +523,14 @@ object Router { * A directed hop over a blinded route composed of multiple (blinded) channels. * Since a blinded route has to be used from start to end, we model it as a single virtual hop. * + * @param nodeId introduction node id * @param dummyId dummy identifier to allow indexing in maps: unlike normal scid aliases, this one doesn't exist * in our routing tables and should be used carefully. * @param route blinded route covered by that hop. * @param paymentInfo payment information about the blinded route. */ - case class BlindedHop(dummyId: Alias, route: BlindedRoute, paymentInfo: OfferTypes.PaymentInfo) extends FinalHop { + case class BlindedHop(nodeId: PublicKey, dummyId: Alias, route: BlindedRoute, paymentInfo: OfferTypes.PaymentInfo) extends FinalHop { // @formatter:off - override val nodeId = route.introductionNodeId override val nextNodeId = route.blindedNodes.last.blindedPublicKey override val cltvExpiryDelta = paymentInfo.cltvExpiryDelta override def fee(amount: MilliSatoshi): MilliSatoshi = paymentInfo.fee(amount) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala index 021ae6c3b4..368f3a8f53 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala @@ -17,11 +17,11 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.{EncodedNodeId, ShortChannelId, UInt64} -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.payment.Bolt12Invoice import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.tlvField +import fr.acinq.eclair.{EncodedNodeId, ShortChannelId, UInt64} import scodec.bits.ByteVector /** Tlv types used inside the onion of an [[OnionMessage]]. */ @@ -148,11 +148,7 @@ object MessageOnionCodecs { import scodec.Codec import scodec.codecs._ - private val replyHopCodec: Codec[BlindedNode] = (("nodeId" | publicKey) :: ("encryptedData" | variableSizeBytes(uint16, bytes))).as[BlindedNode] - - val blindedRouteCodec: Codec[BlindedRoute] = (("firstNodeId" | publicKey) :: ("blinding" | publicKey) :: ("path" | listOfN(uint8, replyHopCodec).xmap[Seq[BlindedNode]](_.toSeq, _.toList))).as[BlindedRoute] - - private val replyPathCodec: Codec[ReplyPath] = tlvField(blindedRouteCodec) + private val replyPathCodec: Codec[ReplyPath] = tlvField(OfferCodecs.blindedRouteCodec) private val encryptedDataCodec: Codec[EncryptedData] = tlvField(bytes) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala index 0c1a42ede3..dc9431bb41 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferCodecs.scala @@ -17,15 +17,14 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.BlockHash -import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequestChain, InvoiceRequestPayerNote, InvoiceRequestQuantity, _} import fr.acinq.eclair.wire.protocol.TlvCodecs.{tlvField, tmillisatoshi, tu32, tu64overflow} import fr.acinq.eclair.{EncodedNodeId, TimestampSecond, UInt64} -import scodec.{Attempt, Codec, Err} import scodec.codecs._ +import scodec.{Attempt, Codec, Err} object OfferCodecs { private val offerChains: Codec[OfferChains] = tlvField(list(blockHash).xmap[Seq[BlockHash]](_.toSeq, _.toList)) @@ -42,17 +41,6 @@ object OfferCodecs { private val offerAbsoluteExpiry: Codec[OfferAbsoluteExpiry] = tlvField(tu64overflow.as[TimestampSecond]) - private val blindedNodeCodec: Codec[BlindedNode] = - (("nodeId" | publicKey) :: - ("encryptedData" | variableSizeBytes(uint16, bytes))).as[BlindedNode] - - private val blindedNodesCodec: Codec[Seq[BlindedNode]] = listOfN(uint8, blindedNodeCodec).xmap(_.toSeq, _.toList) - - private val blindedPathCodec: Codec[BlindedPath] = - (("firstNodeId" | publicKey) :: - ("blinding" | publicKey) :: - ("path" | blindedNodesCodec)).as[BlindedRoute].as[BlindedPath] - private val isNode1: Codec[Boolean] = uint8.narrow( n => if (n == 0) Attempt.Successful(true) else if (n == 1) Attempt.Successful(false) else Attempt.Failure(new Err.MatchingDiscriminatorNotFound(n)), b => if (b) 0 else 1 @@ -64,14 +52,18 @@ object OfferCodecs { val encodedNodeIdCodec: Codec[EncodedNodeId] = choice(shortChannelIdDirCodec.upcast[EncodedNodeId], publicKey.as[EncodedNodeId.Plain].upcast[EncodedNodeId]) - private val compactBlindedPathCodec: Codec[CompactBlindedPath] = - (("introductionNode" | shortChannelIdDirCodec) :: - ("blinding" | publicKey) :: - ("path" | blindedNodesCodec)).as[CompactBlindedPath] + private val blindedNodeCodec: Codec[BlindedNode] = + (("nodeId" | publicKey) :: + ("encryptedData" | variableSizeBytes(uint16, bytes))).as[BlindedNode] + + private val blindedNodesCodec: Codec[Seq[BlindedNode]] = listOfN(uint8, blindedNodeCodec).xmap(_.toSeq, _.toList) - val pathCodec: Codec[BlindedContactInfo] = choice(compactBlindedPathCodec.upcast[BlindedContactInfo], blindedPathCodec.upcast[BlindedContactInfo]) + val blindedRouteCodec: Codec[BlindedRoute] = + (("firstNodeId" | encodedNodeIdCodec) :: + ("blinding" | publicKey) :: + ("path" | blindedNodesCodec)).as[BlindedRoute] - private val offerPaths: Codec[OfferPaths] = tlvField(list(pathCodec).xmap[Seq[BlindedContactInfo]](_.toSeq, _.toList)) + private val offerPaths: Codec[OfferPaths] = tlvField(list(blindedRouteCodec).xmap[Seq[BlindedRoute]](_.toSeq, _.toList)) private val offerIssuer: Codec[OfferIssuer] = tlvField(utf8) @@ -133,7 +125,7 @@ object OfferCodecs { .typecase(UInt64(240), signature) ).complete - private val invoicePaths: Codec[InvoicePaths] = tlvField(list(pathCodec).xmap[Seq[BlindedContactInfo]](_.toSeq, _.toList)) + private val invoicePaths: Codec[InvoicePaths] = tlvField(list(blindedRouteCodec).xmap[Seq[BlindedRoute]](_.toSeq, _.toList)) val paymentInfo: Codec[PaymentInfo] = (("fee_base_msat" | millisatoshi32) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala index 489aba3f01..de5bb18a7b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OfferTypes.scala @@ -19,12 +19,11 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.Bech32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey, XonlyPublicKey} import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, ByteVector64, Crypto, LexicographicalOrdering} -import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir -import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.wire.protocol.CommonCodecs.varint import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv -import fr.acinq.eclair.{Bolt12Feature, CltvExpiryDelta, Feature, Features, MilliSatoshi, RealShortChannelId, TimestampSecond, UInt64, nodeFee, randomBytes32} +import fr.acinq.eclair.{Bolt12Feature, CltvExpiryDelta, Feature, Features, MilliSatoshi, TimestampSecond, UInt64, nodeFee, randomBytes32} import scodec.Codec import scodec.bits.ByteVector import scodec.codecs.vector @@ -42,9 +41,7 @@ object OfferTypes { /** If the offer or invoice issuer doesn't want to hide their identity, they can directly share their public nodeId. */ case class RecipientNodeId(nodeId: PublicKey) extends ContactInfo /** If the offer or invoice issuer wants to hide their identity, they instead provide blinded paths. */ - sealed trait BlindedContactInfo extends ContactInfo - case class BlindedPath(route: BlindedRoute) extends BlindedContactInfo - case class CompactBlindedPath(introductionNode: ShortChannelIdDir, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) extends BlindedContactInfo + case class BlindedPath(route: BlindedRoute) extends ContactInfo // @formatter:on sealed trait Bolt12Tlv extends Tlv @@ -95,7 +92,7 @@ object OfferTypes { /** * Paths that can be used to retrieve an invoice. */ - case class OfferPaths(paths: Seq[BlindedContactInfo]) extends OfferTlv + case class OfferPaths(paths: Seq[BlindedRoute]) extends OfferTlv /** * Name of the offer creator. @@ -155,7 +152,7 @@ object OfferTypes { /** * Payment paths to send the payment to. */ - case class InvoicePaths(paths: Seq[BlindedContactInfo]) extends InvoiceTlv + case class InvoicePaths(paths: Seq[BlindedRoute]) extends InvoiceTlv case class PaymentInfo(feeBase: MilliSatoshi, feeProportionalMillionths: Long, @@ -239,7 +236,7 @@ object OfferTypes { val description: String = records.get[OfferDescription].get.description val features: Features[Bolt12Feature] = records.get[OfferFeatures].map(_.features.bolt12Features()).getOrElse(Features.empty) val expiry: Option[TimestampSecond] = records.get[OfferAbsoluteExpiry].map(_.absoluteExpiry) - private val paths: Option[Seq[BlindedContactInfo]] = records.get[OfferPaths].map(_.paths) + private val paths: Option[Seq[BlindedPath]] = records.get[OfferPaths].map(_.paths.map(BlindedPath)) val issuer: Option[String] = records.get[OfferIssuer].map(_.issuer) val quantityMax: Option[Long] = records.get[OfferQuantityMax].map(_.max).map { q => if (q == 0) Long.MaxValue else q } val nodeId: PublicKey = records.get[OfferNodeId].map(_.publicKey).get diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index d7bfcd03c4..e9e13c4a51 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -18,11 +18,11 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, PaymentBlindedContactInfo} +import fr.acinq.eclair.payment.{Bolt11Invoice, Bolt12Invoice, PaymentBlindedRoute} import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ -import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64, randomKey} +import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} /** @@ -186,7 +186,7 @@ object OnionPaymentPayloadTlv { case class AsyncPayment() extends OnionPaymentPayloadTlv /** Blinded paths to relay the payment to */ - case class OutgoingBlindedPaths(paths: Seq[PaymentBlindedContactInfo]) extends OnionPaymentPayloadTlv + case class OutgoingBlindedPaths(paths: Seq[PaymentBlindedRoute]) extends OnionPaymentPayloadTlv } object PaymentOnion { @@ -537,12 +537,12 @@ object PaymentOnionCodecs { private val trampolineOnion: Codec[TrampolineOnion] = tlvField(OnionRoutingCodecs.variableSizeOnionRoutingPacketCodec) - private val paymentBlindedContactInfo: Codec[PaymentBlindedContactInfo] = - (("route" | OfferCodecs.pathCodec) :: - ("paymentInfo" | OfferCodecs.paymentInfo)).as[PaymentBlindedContactInfo] + private val paymentBlindedRoute: Codec[PaymentBlindedRoute] = + (("route" | OfferCodecs.blindedRouteCodec) :: + ("paymentInfo" | OfferCodecs.paymentInfo)).as[PaymentBlindedRoute] private val outgoingBlindedPaths: Codec[OutgoingBlindedPaths] = - tlvField(list(paymentBlindedContactInfo).xmap[Seq[PaymentBlindedContactInfo]](_.toSeq, _.toList)) + tlvField(list(paymentBlindedRoute).xmap[Seq[PaymentBlindedRoute]](_.toSeq, _.toList)) private val keySend: Codec[KeySend] = tlvField(bytes32) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 27c1087f1e..45fdac0581 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedRoute, BlindedRouteDetails} import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomKey} +import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, MilliSatoshiLong, ShortChannelId, UInt64, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits._ @@ -421,7 +421,7 @@ class SphinxSpec extends AnyFunSuite { } // We now have a blinded route Bob -> Carol -> Dave -> Eve - val blindedRoute = BlindedRoute(bob.publicKey, blinding, blindedRouteStart.blindedNodes ++ blindedRouteEnd.blindedNodes) + val blindedRoute = BlindedRoute(EncodedNodeId(bob.publicKey), blinding, blindedRouteStart.blindedNodes ++ blindedRouteEnd.blindedNodes) assert(blindedRoute.blindedNodeIds == Seq( PublicKey(hex"03da173ad2aee2f701f17e59fbd16cb708906d69838a5f088e8123fb36e89a2c25"), PublicKey(hex"02e466727716f044290abf91a14a6d90e87487da160c2a3cbd0d465d7a78eb83a7"), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index 7dc515d3fd..12c4b8a24c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -801,7 +801,7 @@ object PaymentsDbSpec { def createBolt12Invoice(amount: MilliSatoshi, payerKey: PrivateKey, recipientKey: PrivateKey, preimage: ByteVector32): Bolt12Invoice = { val offer = Offer(Some(amount), "some offer", recipientKey.publicKey, Features.empty, Block.TestnetGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, 789 msat, 1, Features.empty, payerKey, Block.TestnetGenesisBlock.hash) - val dummyRoute = PaymentBlindedContactInfo(BlindedPath(RouteBlinding.create(randomKey(), Seq(randomKey().publicKey), Seq(randomBytes(100))).route), PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 0 msat, Features.empty)) + val dummyRoute = PaymentBlindedRoute(RouteBlinding.create(randomKey(), Seq(randomKey().publicKey), Seq(randomBytes(100))).route, PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 0 msat, Features.empty)) Bolt12Invoice(invoiceRequest, preimage, recipientKey, 1 hour, Features.empty, Seq(dummyRoute)) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala index fefdef46e2..c2a0667076 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala @@ -36,7 +36,7 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.ReplyPath import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv import fr.acinq.eclair.wire.protocol.{GenericTlv, NodeAnnouncement} -import fr.acinq.eclair.{EclairImpl, Features, MilliSatoshi, SendOnionMessageResponse, UInt64, randomBytes, randomKey} +import fr.acinq.eclair.{EclairImpl, EncodedNodeId, Features, MilliSatoshi, SendOnionMessageResponse, UInt64, randomBytes, randomKey} import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.ExecutionContext.Implicits.global @@ -83,7 +83,7 @@ class MessageIntegrationSpec extends IntegrationSpec { nodes("B").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) val blindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(nodes("A").nodeParams.nodeId), IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("B").nodeParams.nodeId)), Recipient(nodes("B").nodeParams.nodeId, None)) - assert(blindedRoute.introductionNodeId == nodes("A").nodeParams.nodeId) + assert(blindedRoute.introductionNodeId == EncodedNodeId(nodes("A").nodeParams.nodeId)) alice.sendOnionMessage(None, Right(blindedRoute), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) @@ -101,7 +101,6 @@ class MessageIntegrationSpec extends IntegrationSpec { val recv = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) assert(recv.finalPayload.records.get[ReplyPath].nonEmpty) val replyPath = recv.finalPayload.records.get[ReplyPath].get.blindedRoute - assert(replyPath.introductionNodeId == nodes("B").nodeParams.nodeId) bob.sendOnionMessage(Some(Nil), Right(replyPath), expectsReply = false, hex"1d01ab") val res = probe.expectMsgType[SendOnionMessageResponse] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala index e5159d9356..de2c93894f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala @@ -31,6 +31,7 @@ import fr.acinq.eclair.blockchain.bitcoind.rpc.BitcoinCoreClient import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel.{BroadcastChannelUpdate, PeriodicRefresh} import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer.PeerRoutingMessage @@ -44,9 +45,9 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendTra import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router.{GossipDecision, PublicChannel} import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, Router} -import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, Offer, OfferPaths} -import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, IncorrectOrUnknownPaymentDetails, OfferTypes} -import fr.acinq.eclair.{CltvExpiryDelta, EclairImpl, Features, Kit, MilliSatoshiLong, ShortChannelId, TimestampMilli, randomBytes32, randomKey} +import fr.acinq.eclair.wire.protocol.OfferTypes.{Offer, OfferPaths} +import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, IncorrectOrUnknownPaymentDetails} +import fr.acinq.eclair.{CltvExpiryDelta, EclairImpl, EncodedNodeId, Features, Kit, MilliSatoshiLong, ShortChannelId, TimestampMilli, randomBytes32, randomKey} import org.json4s.JsonAST.{JString, JValue} import scodec.bits.{ByteVector, HexStringSyntax} @@ -693,9 +694,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { val chain = nodes("D").nodeParams.chainHash val pathId = randomBytes32() val offerPaths = Seq( - OfferTypes.BlindedPath(buildRoute(randomKey(), Seq(IntermediateNode(nodes("G").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("D").nodeParams.nodeId, Some(pathId)))), - OfferTypes.BlindedPath(buildRoute(randomKey(), Seq(IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("D").nodeParams.nodeId, Some(pathId)))), - OfferTypes.BlindedPath(buildRoute(randomKey(), Seq(IntermediateNode(nodes("E").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("D").nodeParams.nodeId, Some(pathId)))) + buildRoute(randomKey(), Seq(IntermediateNode(nodes("G").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("D").nodeParams.nodeId, Some(pathId))), + buildRoute(randomKey(), Seq(IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("D").nodeParams.nodeId, Some(pathId))), + buildRoute(randomKey(), Seq(IntermediateNode(nodes("E").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("D").nodeParams.nodeId, Some(pathId))) ) val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("D").nodeParams.features.bolt12Features(), chain, additionalTlvs = Set(OfferPaths(offerPaths))) val offerHandler = TypedProbe[HandlerCommand]()(nodes("D").system.toTyped) @@ -765,7 +766,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { val amount = 50_000_000 msat val chain = nodes("A").nodeParams.chainHash val pathId = randomBytes32() - val offerPath = OfferTypes.BlindedPath(buildRoute(randomKey(), Seq(IntermediateNode(nodes("A").nodeParams.nodeId), IntermediateNode(nodes("A").nodeParams.nodeId)), Recipient(nodes("A").nodeParams.nodeId, Some(pathId)))) + val offerPath = buildRoute(randomKey(), Seq(IntermediateNode(nodes("A").nodeParams.nodeId), IntermediateNode(nodes("A").nodeParams.nodeId)), Recipient(nodes("A").nodeParams.nodeId, Some(pathId))) val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("A").nodeParams.features.bolt12Features(), chain, additionalTlvs = Set(OfferPaths(Seq(offerPath)))) val offerHandler = TypedProbe[HandlerCommand]()(nodes("A").system.toTyped) nodes("A").offerManager ! RegisterOffer(offer, recipientKey, Some(pathId), offerHandler.ref) @@ -799,7 +800,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { val amount = 10_000_000 msat val chain = nodes("C").nodeParams.chainHash val pathId = randomBytes32() - val offerPath = OfferTypes.BlindedPath(buildRoute(randomKey(), Seq(IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("C").nodeParams.nodeId, Some(pathId)))) + val offerPath = buildRoute(randomKey(), Seq(IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("C").nodeParams.nodeId, Some(pathId))) val offer = Offer(Some(amount), "tricky test offer", recipientKey.publicKey, nodes("C").nodeParams.features.bolt12Features(), chain, additionalTlvs = Set(OfferPaths(Seq(offerPath)))) val offerHandler = TypedProbe[HandlerCommand]()(nodes("C").system.toTyped) nodes("C").offerManager ! RegisterOffer(offer, recipientKey, Some(pathId), offerHandler.ref) @@ -875,7 +876,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { ShortChannelIdDir(channelBE.nodeId1 == nodes("B").nodeParams.nodeId, channelBE.shortChannelId) } val offerBlindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("C").nodeParams.nodeId)), Recipient(nodes("C").nodeParams.nodeId, Some(pathId))) - val offerPath = OfferTypes.CompactBlindedPath(scidDirEB, offerBlindedRoute.blindingKey, offerBlindedRoute.blindedNodes) + val offerPath = BlindedRoute(scidDirEB, offerBlindedRoute.blindingKey, offerBlindedRoute.blindedNodes) val offer = Offer(Some(amount), "test offer", recipientKey.publicKey, nodes("C").nodeParams.features.bolt12Features(), chain, additionalTlvs = Set(OfferPaths(Seq(offerPath)))) val offerHandler = TypedProbe[HandlerCommand]()(nodes("C").system.toTyped) nodes("C").offerManager ! RegisterOffer(offer, recipientKey, Some(pathId), offerHandler.ref) @@ -903,7 +904,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.recipientAmount == amount, paymentSent) assert(paymentSent.feesPaid >= 0.msat, paymentSent) val Some(invoice: Bolt12Invoice) = nodes("A").nodeParams.db.payments.listOutgoingPaymentsToOffer(offer.offerId).head.invoice - assert(invoice.blindedPaths.forall(_.route.isInstanceOf[CompactBlindedPath])) + assert(invoice.blindedPaths.forall(_.route.introductionNodeId.isInstanceOf[EncodedNodeId.ShortChannelIdDir])) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala index 4a8de21eb5..19e2ccb5bc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala @@ -37,8 +37,8 @@ import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendSpo import fr.acinq.eclair.payment.send.{OfferPayment, PaymentLifecycle} import fr.acinq.eclair.testutils.FixtureSpec import fr.acinq.eclair.wire.protocol.OfferTypes.{Offer, OfferPaths} -import fr.acinq.eclair.wire.protocol.{IncorrectOrUnknownPaymentDetails, InvalidOnionBlinding, OfferTypes} -import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} +import fr.acinq.eclair.wire.protocol.{IncorrectOrUnknownPaymentDetails, InvalidOnionBlinding} +import fr.acinq.eclair.{CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} import org.scalatest.concurrent.IntegrationPatience import org.scalatest.{Tag, TestData} import scodec.bits.HexStringSyntax @@ -140,7 +140,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val offerPaths = routes.map(route => { val ourNodeId = route.nodes.last val intermediateNodes = route.nodes.dropRight(1).map(IntermediateNode(_)) ++ route.dummyHops.map(_ => IntermediateNode(ourNodeId)) - OfferTypes.BlindedPath(buildRoute(randomKey(), intermediateNodes, Recipient(ourNodeId, Some(pathId)))) + buildRoute(randomKey(), intermediateNodes, Recipient(ourNodeId, Some(pathId))) }) val offer = Offer(None, "test", recipientKey.publicKey, Features.empty, recipient.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(offerPaths))) val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) @@ -359,10 +359,10 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val recipientKey = randomKey() val pathId = randomBytes32() - val blindedRoute = OfferTypes.BlindedPath(buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId), IntermediateNode(carol.nodeId)), Recipient(carol.nodeId, Some(pathId)))) + val blindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId), IntermediateNode(carol.nodeId)), Recipient(carol.nodeId, Some(pathId))) val offer = Offer(None, "test", recipientKey.publicKey, Features.empty, carol.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(Seq(blindedRoute)))) val scid_bc = getPeerChannels(bob, carol.nodeId).head.data.asInstanceOf[DATA_NORMAL].shortIds.real.toOption.get - val compactBlindedRoute = OfferTypes.BlindedPath(buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId, Some(scid_bc)), IntermediateNode(carol.nodeId, Some(ShortChannelId.toSelf))), Recipient(carol.nodeId, Some(pathId)))) + val compactBlindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId, EncodedNodeId(bob.nodeId), Some(scid_bc)), IntermediateNode(carol.nodeId, EncodedNodeId(carol.nodeId), Some(ShortChannelId.toSelf))), Recipient(carol.nodeId, Some(pathId))) val compactOffer = Offer(None, "test", recipientKey.publicKey, Features.empty, carol.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(Seq(compactBlindedRoute)))) assert(compactOffer.toString.length < offer.toString.length) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 1f9a150db1..50dd5bafa4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.io import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe => TypedProbe} import akka.actor.typed.ActorRef +import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.testkit.TestProbe import com.typesafe.config.ConfigFactory @@ -62,8 +63,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("relay with new connection") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId(bobId)), message, RelayAll, None) @@ -76,8 +76,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("relay with existing peer") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId(bobId)), message, RelayAll, None) @@ -90,8 +89,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("can't open new connection") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId(bobId)), message, RelayAll, Some(probe.ref)) @@ -104,8 +102,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("no channel with previous node") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey relay ! RelayMessage(messageId, previousNodeId, Right(EncodedNodeId(bobId)), message, RelayChannelsOnly, Some(probe.ref)) @@ -121,8 +118,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("no channel with next node") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey relay ! RelayMessage(messageId, previousNodeId, Right(EncodedNodeId(bobId)), message, RelayChannelsOnly, Some(probe.ref)) @@ -142,8 +138,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("channels on both ends") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey relay ! RelayMessage(messageId, previousNodeId, Right(EncodedNodeId(bobId)), message, RelayChannelsOnly, None) @@ -162,8 +157,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("next node specified with channel id") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val scid = ShortChannelId(123456L) relay ! RelayMessage(messageId, randomKey().publicKey, Left(scid), message, RelayAll, None) @@ -181,8 +175,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("next node is compact node id") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) - assert(nextNode == bobId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val scid = RealShortChannelId(234567L) relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.ShortChannelIdDir(isNode1 = false, scid)), message, RelayAll, None) @@ -201,8 +194,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app test("next node is us as compact node id") { f => import f._ - val Right((nextNode, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(31), hex"f3ed")))) - assert(nextNode == aliceId) + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(31), hex"f3ed")))) val messageId = randomBytes32() val scid = RealShortChannelId(345678L) relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.ShortChannelIdDir(isNode1 = true, scid)), message, RelayAll, None) @@ -219,4 +211,27 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val OnionMessages.ReceiveMessage(payload) = OnionMessages.process(Bob.nodeParams.privateKey, messageToBob) assert(payload.records.unknown == Set(GenericTlv(UInt64(31), hex"f3ed"))) } + + test("relay to self and receive") { f => + import f._ + + val probe = TypedProbe[OnionMessages.ReceiveMessage]() + system.eventStream ! EventStream.Subscribe(probe.ref) + + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq( + IntermediateNode(aliceId, EncodedNodeId(aliceId), outgoingChannel_opt = Some(ShortChannelId.toSelf)), + IntermediateNode(aliceId), + IntermediateNode(aliceId, EncodedNodeId.ShortChannelIdDir(isNode1 = false, scid = RealShortChannelId(123L))) + ), Recipient(aliceId, None), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd")))) + val messageId = randomBytes32() + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId(aliceId)), message, RelayAll, None) + + val getNodeId = router.expectMessageType[Router.GetNodeId] + assert(getNodeId.isNode1 == false) + assert(getNodeId.shortChannelId == RealShortChannelId(123L)) + getNodeId.replyTo ! Some(aliceId) + + val OnionMessages.ReceiveMessage(finalPayload) = probe.expectMessageType[OnionMessages.ReceiveMessage] + assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala index 5f4eac9775..e07e6251a5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala @@ -402,7 +402,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer, isPersistent = false) val probe = TestProbe() - val Right((_, message)) = buildMessage(nodeParams.privateKey, randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) + val Right(message) = buildMessage(randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) probe.send(peerConnection, message) probe watch peerConnection probe.expectTerminated(peerConnection, max = 1500 millis) @@ -418,7 +418,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer, isPersistent = false) val probe = TestProbe() - val Right((_, message)) = buildMessage(nodeParams.privateKey, randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) + val Right(message) = buildMessage(randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) probe watch peerConnection probe.send(peerConnection, message) // The connection is still open for a short while. @@ -431,7 +431,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer, isPersistent = false) val probe = TestProbe() - val Right((_, message)) = buildMessage(nodeParams.privateKey, randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) + val Right(message) = buildMessage(randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) probe.send(peerConnection, message) assert(peerConnection.stateName == PeerConnection.CONNECTED) probe.send(peerConnection, ChannelReady(ByteVector32(hex"0000000000000000000000000000000000000000000000000000000000000000"), randomKey().publicKey)) @@ -444,7 +444,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi test("incoming rate limiting") { f => import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer, isPersistent = true) - val Right((_, message)) = buildMessage(nodeParams.privateKey, randomKey(), randomKey(), Nil, Recipient(nodeParams.nodeId, None), TlvStream.empty) + val Right(message) = buildMessage(randomKey(), randomKey(), Nil, Recipient(nodeParams.nodeId, None), TlvStream.empty) for (_ <- 1 to 30) { transport.send(peerConnection, message) } @@ -463,7 +463,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi test("outgoing rate limiting") { f => import f._ connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer, isPersistent = true) - val Right((_, message)) = buildMessage(nodeParams.privateKey, randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) + val Right(message) = buildMessage(randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) for (_ <- 1 to 30) { peer.send(peerConnection, message) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 021d4ed2f0..df59c76220 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -598,7 +598,7 @@ class PeerSpec extends FixtureSpec { test("reply to relay request") { f => import f._ connect(remoteNodeId, peer, peerConnection, switchboard, channels = Set(ChannelCodecsSpec.normal)) - val Right((_, msg)) = buildMessage(nodeParams.privateKey, randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) + val Right(msg) = buildMessage(randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) val messageId = randomBytes32() val probe = TestProbe() peer ! RelayOnionMessage(messageId, msg, Some(probe.ref.toTyped)) @@ -607,7 +607,7 @@ class PeerSpec extends FixtureSpec { test("reply to relay request disconnected") { f => import f._ - val Right((_, msg)) = buildMessage(nodeParams.privateKey, randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) + val Right(msg) = buildMessage(randomKey(), randomKey(), Nil, Recipient(remoteNodeId, None), TlvStream.empty) val messageId = randomBytes32() val probe = TestProbe() peer ! RelayOnionMessage(messageId, msg, Some(probe.ref.toTyped)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala index 73f356ccee..47cf1f1ef2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala @@ -40,12 +40,10 @@ import scodec.bits.{ByteVector, HexStringSyntax} class OnionMessagesSpec extends AnyFunSuite { test("single-hop onion message without path_id") { - val nodeKey = randomKey() val sessionKey = randomKey() val blindingSecret = randomKey() val destination = randomKey() - val Right((nextNodeId, message)) = buildMessage(nodeKey, sessionKey, blindingSecret, Nil, Recipient(destination.publicKey, None), TlvStream.empty) - assert(nextNodeId == destination.publicKey) + val Right(message) = buildMessage(sessionKey, blindingSecret, Nil, Recipient(destination.publicKey, None), TlvStream.empty) process(destination, message) match { case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.isEmpty) @@ -108,9 +106,9 @@ class OnionMessagesSpec extends AnyFunSuite { val onionForAlice = OnionMessage(blindingSecret.publicKey, packet) // Building the onion with functions from `OnionMessages` - val replyPath = buildRoute(blindingOverride, IntermediateNode(carol.publicKey, padding = Some(hex"0000000000000000000000000000000000000000000000000000000000000000000000")) :: Nil, Recipient(dave.publicKey, pathId = Some(hex"01234567"))) + val replyPath = buildRoute(blindingOverride, IntermediateNode(carol.publicKey, EncodedNodeId(carol.publicKey), padding = Some(hex"0000000000000000000000000000000000000000000000000000000000000000000000")) :: Nil, Recipient(dave.publicKey, pathId = Some(hex"01234567"))) assert(replyPath == routeFromCarol) - val Right((_, message)) = buildMessage(randomKey(), sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, BlindedPath(replyPath), TlvStream.empty) + val Right(message) = buildMessage(sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, BlindedPath(replyPath), TlvStream.empty) assert(message == onionForAlice) // Checking that the onion is relayed properly @@ -204,33 +202,35 @@ class OnionMessagesSpec extends AnyFunSuite { } test("build message with existing route") { - val nodeKey = randomKey() val sessionKey = randomKey() val blindingSecret = randomKey() val blindingOverride = randomKey() val destination = randomKey() val replyPath = buildRoute(blindingOverride, IntermediateNode(destination.publicKey) :: Nil, Recipient(destination.publicKey, pathId = Some(hex"01234567"))) assert(replyPath.blindingKey == blindingOverride.publicKey) - assert(replyPath.introductionNodeId == destination.publicKey) - val Right((nextNodeId, message)) = buildMessage(nodeKey, sessionKey, blindingSecret, Nil, BlindedPath(replyPath), TlvStream.empty) - assert(nextNodeId == destination.publicKey) + assert(replyPath.introductionNodeId == EncodedNodeId(destination.publicKey)) + val Right(message) = buildMessage(sessionKey, blindingSecret, Nil, BlindedPath(replyPath), TlvStream.empty) assert(message.blindingKey == blindingOverride.publicKey) // blindingSecret was not used as the replyPath was used as is process(destination, message) match { - case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(hex"01234567")) + case SendMessage(Right(EncodedNodeId.Plain(nextNodeId2)), message2) => + assert(nextNodeId2 == destination.publicKey) + process(destination, message2) match { + case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(hex"01234567")) + case x => fail(x.toString) + } case x => fail(x.toString) } } test("very large multi-hop onion message") { - val nodeKey = randomKey() val alice = randomKey() val bob = randomKey() val carol = randomKey() val sessionKey = randomKey() val blindingSecret = randomKey() val pathId = randomBytes(65201) - val Right((_, messageForAlice)) = buildMessage(nodeKey, sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) + val Right(messageForAlice) = buildMessage(sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) // Checking that the onion is relayed properly process(alice, messageForAlice) match { @@ -250,7 +250,6 @@ class OnionMessagesSpec extends AnyFunSuite { } test("too large multi-hop onion message") { - val nodeKey = randomKey() val alice = randomKey() val bob = randomKey() val carol = randomKey() @@ -259,7 +258,7 @@ class OnionMessagesSpec extends AnyFunSuite { val pathId = randomBytes(65202) - assert(buildMessage(nodeKey, sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) == Left(MessageTooLarge(65433))) + assert(buildMessage(sessionKey, blindingSecret, IntermediateNode(alice.publicKey) :: IntermediateNode(bob.publicKey) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) == Left(MessageTooLarge(65433))) } test("reference test vector") { @@ -286,7 +285,7 @@ class OnionMessagesSpec extends AnyFunSuite { Recipient(nodeKey.publicKey, Some(ByteVector.fromValidHex((json \ "path_id").extract[String])), (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) def makeIntermediateNode(nodeKey: PrivateKey, json: JValue): IntermediateNode = - IntermediateNode(nodeKey.publicKey, None, (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) + IntermediateNode(nodeKey.publicKey, EncodedNodeId(nodeKey.publicKey), None, (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) val blindingSecretBob = PrivateKey(ByteVector32.fromValidHex(((testVector \ "generate" \ "hops")(1) \ "blinding_secret").extract[String])) val pathId = ByteVector.fromValidHex(((testVector \ "generate" \ "hops")(3) \ "tlvs" \ "path_id").extract[String]) @@ -296,10 +295,10 @@ class OnionMessagesSpec extends AnyFunSuite { makeRecipient(dave, (testVector \ "generate" \ "hops")(3) \ "tlvs")) val blindingSecretAlice = PrivateKey(ByteVector32.fromValidHex(((testVector \ "generate" \ "hops")(0) \ "blinding_secret").extract[String])) val intermediateAlice = Seq(makeIntermediateNode(alice, (testVector \ "generate" \ "hops")(0) \ "tlvs")) - val Some(pathAliceToDave) = buildRouteFrom(alice, blindingSecretAlice, intermediateAlice, BlindedPath(pathBobToDave)) + val pathAliceToDave = buildRouteFrom(blindingSecretAlice, intermediateAlice, BlindedPath(pathBobToDave)) val expectedPath = BlindedRoute( - PublicKey(ByteVector.fromValidHex((testVector \ "route" \ "introduction_node_id").extract[String])), + EncodedNodeId(PublicKey(ByteVector.fromValidHex((testVector \ "route" \ "introduction_node_id").extract[String]))), PublicKey(ByteVector.fromValidHex((testVector \ "route" \ "blinding").extract[String])), Seq( BlindedNode( @@ -321,7 +320,7 @@ class OnionMessagesSpec extends AnyFunSuite { val sessionKey = PrivateKey(ByteVector32.fromValidHex((testVector \ "generate" \ "session_key").extract[String])) val messageContent = TlvStream(Set.empty[OnionMessagePayloadTlv], getCustomTlvs(testVector \ "onionmessage")) - val Right((_, message)) = buildMessage(alice, sessionKey, blindingSecretAlice, intermediateAlice, BlindedPath(pathBobToDave), messageContent) + val Right(message) = buildMessage(sessionKey, blindingSecretAlice, intermediateAlice, BlindedPath(pathBobToDave), messageContent) val encodedPacket = OnionRoutingCodecs.onionRoutingPacketCodec(1300).encode(message.onionRoutingPacket).require.bytes.toHex val expectedPacket = (testVector \ "onionmessage" \ "onion_message_packet").extract[String] assert(encodedPacket == expectedPacket) @@ -349,7 +348,6 @@ class OnionMessagesSpec extends AnyFunSuite { } test("route with channel ids") { - val nodeKey = randomKey() val alice = randomKey() val alice2bob = ShortChannelId(1) val bob = randomKey() @@ -358,7 +356,7 @@ class OnionMessagesSpec extends AnyFunSuite { val sessionKey = randomKey() val blindingSecret = randomKey() val pathId = randomBytes(64) - val Right((_, messageForAlice)) = buildMessage(nodeKey, sessionKey, blindingSecret, IntermediateNode(alice.publicKey, outgoingChannel_opt = Some(alice2bob)) :: IntermediateNode(bob.publicKey, outgoingChannel_opt = Some(bob2carol)) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) + val Right(messageForAlice) = buildMessage(sessionKey, blindingSecret, IntermediateNode(alice.publicKey, EncodedNodeId(alice.publicKey), outgoingChannel_opt = Some(alice2bob)) :: IntermediateNode(bob.publicKey, EncodedNodeId(bob.publicKey), outgoingChannel_opt = Some(bob2carol)) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) // Checking that the onion is relayed properly process(alice, messageForAlice) match { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala index e68c52d9fa..fbae409a98 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala @@ -22,7 +22,7 @@ import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Block -import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.io.MessageRelay.{Disconnected, Sent} import fr.acinq.eclair.io.PeerConnection.ConnectionResult @@ -35,12 +35,13 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteRequest} import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.{InvoiceRequest, ReplyPath} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.PathId -import fr.acinq.eclair.wire.protocol.{GenericTlv, MessageOnion, OfferTypes, OnionMessagePayloadTlv, TlvStream} -import fr.acinq.eclair.{Features, MilliSatoshiLong, EncodedNodeId, NodeParams, RealShortChannelId, TestConstants, UInt64, randomKey} +import fr.acinq.eclair.wire.protocol.{GenericTlv, MessageOnion, OfferTypes, OnionMessage, OnionMessagePayloadTlv, TlvStream} +import fr.acinq.eclair.{EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, UInt64, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.HexStringSyntax +import scala.annotation.tailrec import scala.concurrent.duration.DurationInt class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { @@ -71,6 +72,18 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat peer.expectMessageType[Peer.RelayOnionMessage] } + @tailrec + private def receive(privateKeys: Seq[PrivateKey], message: OnionMessage): MessageOnion.FinalPayload = { + OnionMessages.process(privateKeys.head, message) match { + case OnionMessages.SendMessage(nextNode, nextMessage) if nextNode == Left(ShortChannelId.toSelf) || nextNode == Right(EncodedNodeId(privateKeys.head.publicKey)) => + receive(privateKeys, nextMessage) + case OnionMessages.SendMessage(nextNode, nextMessage) if nextNode == Right(EncodedNodeId(privateKeys(1).publicKey)) => + receive(privateKeys.tail, nextMessage) + case ReceiveMessage(finalPayload) => finalPayload + case _ => fail() + } + } + test("message forwarded only once") { f => import f._ @@ -89,8 +102,8 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) val replyPath = finalPayload.records.get[ReplyPath].get.blindedRoute - val Right((_, reply)) = buildMessage(recipientKey, randomKey(), randomKey(), Nil, BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(55), hex"1234")))) - val ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, reply) + val Right(reply) = buildMessage(randomKey(), randomKey(), Nil, BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(55), hex"1234")))) + val replyPayload = receive(Seq(recipientKey, nodeParams.privateKey), reply) testKit.system.eventStream ! EventStream.Publish(ReceiveMessage(replyPayload)) testKit.system.eventStream ! EventStream.Publish(ReceiveMessage(replyPayload)) @@ -138,8 +151,8 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat messageSender.expectMessage(NoReply) val replyPath = finalPayload.records.get[ReplyPath].get.blindedRoute - val Right((_, reply)) = buildMessage(recipientKey, randomKey(), randomKey(), Nil, BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(55), hex"1234")))) - val ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, reply) + val Right(reply) = buildMessage(randomKey(), randomKey(), Nil, BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(55), hex"1234")))) + val replyPayload = receive(Seq(recipientKey, nodeParams.privateKey), reply) testKit.system.eventStream ! EventStream.Publish(ReceiveMessage(replyPayload)) messageSender.expectNoMessage(10 millis) @@ -190,7 +203,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val offer = OfferTypes.Offer(None, "", randomKey().publicKey, Features.empty, Block.LivenetGenesisBlock.hash) val invoiceRequest = OfferTypes.InvoiceRequest(offer, 1000 msat, 1, Features.empty, randomKey(), Block.LivenetGenesisBlock.hash) - val replyPath = BlindedRoute(randomKey().publicKey, randomKey().publicKey, Seq(BlindedNode(randomKey().publicKey, hex""))) + val replyPath = BlindedRoute(EncodedNodeId(randomKey().publicKey), randomKey().publicKey, Seq(BlindedNode(randomKey().publicKey, hex""))) val invoiceRequestPayload = MessageOnion.InvoiceRequestPayload(TlvStream(InvoiceRequest(invoiceRequest.records), ReplyPath(replyPath)), TlvStream(PathId(hex"abcd"))) postman ! WrappedMessage(invoiceRequestPayload) @@ -222,19 +235,12 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(payload.records.unknown == Set(GenericTlv(UInt64(11), hex"012345"))) assert(payload.records.get[ReplyPath].nonEmpty) val replyPath = payload.records.get[ReplyPath].get.blindedRoute - assert(replyPath.introductionNodeId == d.publicKey) + assert(replyPath.introductionNodeId == EncodedNodeId(d.publicKey)) assert(replyPath.length >= nodeParams.onionMessageConfig.minIntermediateHops) assert(nodeParams.onionMessageConfig.minIntermediateHops > 5) - val Right((next5, reply)) = OnionMessages.buildMessage(d, randomKey(), randomKey(), Nil, OnionMessages.BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(13), hex"6789")))) - assert(next5 == c.publicKey) - val OnionMessages.SendMessage(Right(next6), message6) = OnionMessages.process(c, reply) - assert(next6 == EncodedNodeId(b.publicKey)) - val OnionMessages.SendMessage(Right(next7), message7) = OnionMessages.process(b, message6) - assert(next7 == EncodedNodeId(a.publicKey)) - val OnionMessages.SendMessage(Right(next8), message8) = OnionMessages.process(a, message7) - assert(next8 == EncodedNodeId(nodeParams.nodeId)) - val OnionMessages.ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, message8) + val Right(reply) = OnionMessages.buildMessage(randomKey(), randomKey(), Nil, OnionMessages.BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(13), hex"6789")))) + val replyPayload = receive(Seq(d, c, b, a, nodeParams.privateKey), reply) postman ! WrappedMessage(replyPayload) assert(replyPayload.records.unknown == Set(GenericTlv(UInt64(13), hex"6789"))) @@ -246,7 +252,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val recipientKey = randomKey() val route = buildRoute(randomKey(), Seq(), Recipient(recipientKey.publicKey, None)) - val compactRoute = OfferTypes.CompactBlindedPath(EncodedNodeId.ShortChannelIdDir(isNode1 = false, RealShortChannelId(1234)), route.blindingKey, route.blindedNodes) + val compactRoute = OfferTypes.BlindedPath(route.copy(introductionNodeId = EncodedNodeId.ShortChannelIdDir(isNode1 = false, RealShortChannelId(1234)))) postman ! SendMessage(compactRoute, FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = false, messageSender.ref) val getNodeId = router.expectMessageType[Router.GetNodeId] @@ -280,7 +286,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val recipientKey = randomKey() val route = buildRoute(randomKey(), Seq(IntermediateNode(nodeParams.nodeId)), Recipient(recipientKey.publicKey, None)) - val compactRoute = OfferTypes.CompactBlindedPath(EncodedNodeId.ShortChannelIdDir(isNode1 = true, RealShortChannelId(1234)), route.blindingKey, route.blindedNodes) + val compactRoute = OfferTypes.BlindedPath(route.copy(introductionNodeId = EncodedNodeId.ShortChannelIdDir(isNode1 = true, RealShortChannelId(1234)))) postman ! SendMessage(compactRoute, FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = false, messageSender.ref) val getNodeId = router.expectMessageType[Router.GetNodeId] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala index 537b5a9ba1..2c4dd9707a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/Bolt12InvoiceSpec.scala @@ -49,9 +49,9 @@ class Bolt12InvoiceSpec extends AnyFunSuite { signedInvoice } - def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedContactInfo = { + def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes - PaymentBlindedContactInfo(OfferTypes.BlindedPath(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route), PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) + PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) } test("check invoice signature") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index e767e2c5b5..8564c14c28 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -38,7 +38,7 @@ import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv._ import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{PathId, PaymentConstraints} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, TestKitBaseClass, TimestampMilli, TimestampMilliLong, randomBytes, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Feature, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, TestKitBaseClass, TimestampMilli, TimestampMilliLong, randomBytes, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.{ByteVector, HexStringSyntax} @@ -294,19 +294,19 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.description == Left("a blinded coffee please")) assert(invoice.invoiceRequest.offer == offer) assert(invoice.blindedPaths.length == 3) - assert(invoice.blindedPaths(0).route.asInstanceOf[OfferTypes.BlindedPath].route.blindedNodeIds.length == 4) - assert(invoice.blindedPaths(0).route.asInstanceOf[OfferTypes.BlindedPath].route.introductionNodeId == a) + assert(invoice.blindedPaths(0).route.blindedNodeIds.length == 4) + assert(invoice.blindedPaths(0).route.introductionNodeId == EncodedNodeId(a)) assert(invoice.blindedPaths(0).paymentInfo == PaymentInfo(1950 msat, 0, CltvExpiryDelta(193), 1 msat, 25_000 msat, Features.empty)) - assert(invoice.blindedPaths(1).route.asInstanceOf[OfferTypes.BlindedPath].route.blindedNodeIds.length == 4) - assert(invoice.blindedPaths(1).route.asInstanceOf[OfferTypes.BlindedPath].route.introductionNodeId == c) + assert(invoice.blindedPaths(1).route.blindedNodeIds.length == 4) + assert(invoice.blindedPaths(1).route.introductionNodeId == EncodedNodeId(c)) assert(invoice.blindedPaths(1).paymentInfo == PaymentInfo(400 msat, 0, CltvExpiryDelta(183), 1 msat, 25_000 msat, Features.empty)) - assert(invoice.blindedPaths(2).route.asInstanceOf[OfferTypes.BlindedPath].route.blindedNodeIds.length == 1) - assert(invoice.blindedPaths(2).route.asInstanceOf[OfferTypes.BlindedPath].route.introductionNodeId == d) + assert(invoice.blindedPaths(2).route.blindedNodeIds.length == 1) + assert(invoice.blindedPaths(2).route.introductionNodeId == EncodedNodeId(d)) assert(invoice.blindedPaths(2).paymentInfo == PaymentInfo(0 msat, 0, CltvExpiryDelta(18), 0 msat, 25_000 msat, Features.empty)) // Offer invoices shouldn't be stored in the DB until we receive a payment for it. assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).isEmpty) // Check that all non-final encrypted payloads for blinded routes have the same length. - assert(invoice.blindedPaths.flatMap(_.route.asInstanceOf[OfferTypes.BlindedPath].route.encryptedPayloads.dropRight(1)).map(_.length).toSet.size == 1) + assert(invoice.blindedPaths.flatMap(_.route.encryptedPayloads.dropRight(1)).map(_.length).toSet.size == 1) } test("Invoice generation with route blinding should fail when router returns an error") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 90c9647434..8d98d1ef36 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -29,15 +29,16 @@ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.OutgoingPaymentPacket.Upstream import fr.acinq.eclair.payment.PaymentPacketSpec._ import fr.acinq.eclair.payment.PaymentSent.PartialPayment +import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.ResolvedPath import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment import fr.acinq.eclair.payment.send.PaymentError.UnsupportedFeatures import fr.acinq.eclair.payment.send.PaymentInitiator._ import fr.acinq.eclair.payment.send._ import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.router.{BlindedRouteCreation, RouteNotFound} -import fr.acinq.eclair.wire.protocol.OfferTypes.{BlindedPath, InvoiceRequest, Offer} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshiLong, NodeParams, PaymentFinalExpiryConf, TestConstants, TestKitBaseClass, TimestampSecond, UnknownFeature, randomBytes32, randomKey} +import fr.acinq.eclair.{Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Feature, Features, MilliSatoshiLong, NodeParams, PaymentFinalExpiryConf, TestConstants, TestKitBaseClass, TimestampSecond, UnknownFeature, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} @@ -296,16 +297,16 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike def createBolt12Invoice(features: Features[Bolt12Feature], payerKey: PrivateKey): Bolt12Invoice = { val offer = Offer(None, "Bolt12 r0cks", e, features, Block.RegtestGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, finalAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) - val blindedRoute = OfferTypes.BlindedPath(BlindedRouteCreation.createBlindedRouteWithoutHops(e, hex"2a2a2a2a", 1 msat, CltvExpiry(500_000)).route) + val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(e, hex"2a2a2a2a", 1 msat, CltvExpiry(500_000)).route val paymentInfo = OfferTypes.PaymentInfo(1_000 msat, 0, CltvExpiryDelta(24), 0 msat, finalAmount, Features.empty) - Bolt12Invoice(invoiceRequest, paymentPreimage, priv_e.privateKey, 300 seconds, features, Seq(PaymentBlindedContactInfo(blindedRoute, paymentInfo))) + Bolt12Invoice(invoiceRequest, paymentPreimage, priv_e.privateKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) } test("forward single-part blinded payment") { f => import f._ val payerKey = randomKey() val invoice = createBolt12Invoice(Features.empty, payerKey) - val resolvedPaths = invoice.blindedPaths.map(path => PaymentBlindedRoute(path.route.asInstanceOf[BlindedPath].route, path.paymentInfo)) + val resolvedPaths = invoice.blindedPaths.map(path => ResolvedPath(path, path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey)) val req = SendPaymentToNode(sender.ref, finalAmount, invoice, resolvedPaths, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams, payerKey_opt = Some(payerKey)) sender.send(initiator, req) val id = sender.expectMsgType[UUID] @@ -336,7 +337,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val payerKey = randomKey() val invoice = createBolt12Invoice(Features(BasicMultiPartPayment -> Optional), payerKey) - val resolvedPaths = invoice.blindedPaths.map(path => PaymentBlindedRoute(path.route.asInstanceOf[BlindedPath].route, path.paymentInfo)) + val resolvedPaths = invoice.blindedPaths.map(path => ResolvedPath(path, path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey)) val req = SendPaymentToNode(sender.ref, finalAmount, invoice, resolvedPaths, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams, payerKey_opt = Some(payerKey)) sender.send(initiator, req) val id = sender.expectMsgType[UUID] @@ -365,7 +366,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike test("reject blinded payment when route blinding deactivated", Tag(Tags.DisableRouteBlinding)) { f => import f._ val invoice = createBolt12Invoice(Features(BasicMultiPartPayment -> Optional), randomKey()) - val resolvedPaths = invoice.blindedPaths.map(path => PaymentBlindedRoute(path.route.asInstanceOf[BlindedPath].route, path.paymentInfo)) + val resolvedPaths = invoice.blindedPaths.map(path => ResolvedPath(path, path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey)) val req = SendPaymentToNode(sender.ref, finalAmount, invoice, resolvedPaths, 1, routeParams = nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams) sender.send(initiator, req) val id = sender.expectMsgType[UUID] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 624d566df9..ac2cb4af2d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -27,17 +27,18 @@ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.crypto.{ShaChain, Sphinx} import fr.acinq.eclair.payment.IncomingPaymentPacket.{ChannelRelayPacket, FinalPacket, RelayToTrampolinePacket, decrypt} import fr.acinq.eclair.payment.OutgoingPaymentPacket._ +import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.ResolvedPath import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient, TrampolineRecipient} import fr.acinq.eclair.router.BaseRouterSpec.{blindedRouteFromHops, channelHopFromUpdate} import fr.acinq.eclair.router.BlindedRouteCreation import fr.acinq.eclair.router.Router.{NodeHop, Route} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions.InputInfo -import fr.acinq.eclair.wire.protocol.OfferTypes.{BlindedPath, InvoiceRequest, Offer, PaymentInfo} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv.{AmountToForward, OutgoingCltv, PaymentData} import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{BlockHeight, Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TestConstants, TimestampMilli, TimestampSecondLong, UInt64, nodeFee, randomBytes32, randomKey} +import fr.acinq.eclair.{BlockHeight, Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, TestConstants, TimestampMilli, TimestampSecondLong, UInt64, nodeFee, randomBytes32, randomKey} import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -218,10 +219,10 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val features = Features[Bolt12Feature](BasicMultiPartPayment -> Optional) val offer = Offer(None, "Bolt12 r0cks", recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) - val blindedRoute = OfferTypes.BlindedPath(BlindedRouteCreation.createBlindedRouteWithoutHops(c, hex"deadbeef", 1 msat, CltvExpiry(500_000)).route) + val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(c, hex"deadbeef", 1 msat, CltvExpiry(500_000)).route val paymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 1 msat, amount_bc, Features.empty) - val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, recipientKey, 300 seconds, features, Seq(PaymentBlindedContactInfo(blindedRoute, paymentInfo))) - val resolvedPaths = invoice.blindedPaths.map(path => PaymentBlindedRoute(path.route.asInstanceOf[BlindedPath].route, path.paymentInfo)) + val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, recipientKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + val resolvedPaths = invoice.blindedPaths.map(path => ResolvedPath(path, path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey)) val recipient = BlindedRecipient(invoice, resolvedPaths, amount_bc, expiry_bc, Set.empty) val hops = Seq(channelHopFromUpdate(a, b, channelUpdate_ab), channelHopFromUpdate(b, c, channelUpdate_bc)) val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(amount_bc, hops, Some(recipient.blindedHops.head)), recipient) @@ -493,10 +494,10 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) // We send the wrong blinded payload to the introduction node. val tmpBlindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Seq(channelHopFromUpdate(b, c, channelUpdate_bc)), hex"deadbeef", 1 msat, CltvExpiry(500_000)).route - val blindedRoute = OfferTypes.BlindedPath(tmpBlindedRoute.copy(blindedNodes = tmpBlindedRoute.blindedNodes.reverse)) + val blindedRoute = tmpBlindedRoute.copy(blindedNodes = tmpBlindedRoute.blindedNodes.reverse) val paymentInfo = OfferTypes.PaymentInfo(fee_b, 0, channelUpdate_bc.cltvExpiryDelta, 0 msat, amount_bc, Features.empty) - val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, priv_c.privateKey, 300 seconds, features, Seq(PaymentBlindedContactInfo(blindedRoute, paymentInfo))) - val resolvedPaths = invoice.blindedPaths.map(path => PaymentBlindedRoute(path.route.asInstanceOf[BlindedPath].route, path.paymentInfo)) + val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, priv_c.privateKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) + val resolvedPaths = invoice.blindedPaths.map(path => ResolvedPath(path, path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey)) val recipient = BlindedRecipient(invoice, resolvedPaths, amount_bc, expiry_bc, Set.empty) val route = Route(amount_bc, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), Some(recipient.blindedHops.head)) (route, recipient) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala index 9f47005c96..bf273c2503 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala @@ -88,7 +88,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app import f._ assert(invoice.blindedPaths.length == 1) - val blindedPath = invoice.blindedPaths.head.route.asInstanceOf[OfferTypes.BlindedPath].route + val blindedPath = invoice.blindedPaths.head.route val Right(RouteBlindingDecryptedData(encryptedDataTlvs, _)) = RouteBlindingEncryptedDataCodecs.decode(nodeParams.privateKey, blindedPath.blindingKey, blindedPath.encryptedPayloads.head) val paymentTlvs = TlvStream[OnionPaymentPayloadTlv]( OnionPaymentPayloadTlv.AmountToForward(invoice.amount), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 00bd63127f..d72699e2fa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -828,9 +828,9 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl } } - def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedContactInfo = { + def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes - PaymentBlindedContactInfo(OfferTypes.BlindedPath(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route), PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) + PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) } test("relay to blinded paths without multi-part") { f => @@ -918,9 +918,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val offer = Offer(None, "test offer", outgoingNodeId, Features.empty, chain) val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) - val BlindedPath(blindedRoute) = paymentBlindedRoute.route val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = CompactBlindedPath(scidDir, blindedRoute.blindingKey, blindedRoute.blindedNodes)) + val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( incoming.innerPayload.amountToForward, outgoingExpiry, invoice @@ -931,7 +930,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val getNodeId = router.expectMessageType[Router.GetNodeId] assert(getNodeId.isNode1 == scidDir.isNode1) assert(getNodeId.shortChannelId == scidDir.scid) - getNodeId.replyTo ! Some(blindedRoute.introductionNodeId) + getNodeId.replyTo ! Some(outgoingNodeId) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) @@ -966,9 +965,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val offer = Offer(None, "test offer", outgoingNodeId, Features.empty, chain) val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) - val BlindedPath(blindedRoute) = paymentBlindedRoute.route val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = CompactBlindedPath(scidDir, blindedRoute.blindingKey, blindedRoute.blindedNodes)) + val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( incoming.innerPayload.amountToForward, outgoingExpiry, invoice diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala index 4f76d7a5d9..b2ba9642de 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala @@ -27,13 +27,13 @@ import fr.acinq.eclair.message.OnionMessages.RoutingStrategy.FindRoute import fr.acinq.eclair.message.Postman import fr.acinq.eclair.payment.send.OfferPayment._ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentToNode -import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedContactInfo} +import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedRoute} import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.wire.protocol.MessageOnion.InvoicePayload import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.{OfferTypes, OnionMessagePayloadTlv, TlvStream} -import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.HexStringSyntax @@ -75,7 +75,7 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs) val preimage = randomBytes32() - val paymentRoute = PaymentBlindedContactInfo(OfferTypes.BlindedPath(RouteBlinding.create(randomKey(), Seq(merchantKey.publicKey), Seq(hex"7777")).route), PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 1_000_000_000 msat, Features.empty)) + val paymentRoute = PaymentBlindedRoute(RouteBlinding.create(randomKey(), Seq(merchantKey.publicKey), Seq(hex"7777")).route, PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 1_000_000_000 msat, Features.empty)) val invoice = Bolt12Invoice(invoiceRequest, preimage, merchantKey, 1 minute, Features.empty, Seq(paymentRoute)) replyTo ! Postman.Response(InvoicePayload(TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)), TlvStream.empty)) val send = paymentInitiator.expectMsgType[SendPaymentToNode] @@ -122,7 +122,7 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs) val preimage = randomBytes32() - val paymentRoute = PaymentBlindedContactInfo(OfferTypes.BlindedPath(RouteBlinding.create(randomKey(), Seq(merchantKey.publicKey), Seq(hex"7777")).route), PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 1_000_000_000 msat, Features.empty)) + val paymentRoute = PaymentBlindedRoute(RouteBlinding.create(randomKey(), Seq(merchantKey.publicKey), Seq(hex"7777")).route, PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 1_000_000_000 msat, Features.empty)) val invoice = Bolt12Invoice(invoiceRequest, preimage, randomKey(), 1 minute, Features.empty, Seq(paymentRoute)) replyTo ! Postman.Response(InvoicePayload(TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)), TlvStream.empty)) @@ -149,12 +149,12 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val preimage = randomBytes32() val blindedRoutes = Seq.fill(6)(RouteBlinding.create(randomKey(), Seq.fill(3)(randomKey().publicKey), Seq.fill(3)(randomBytes(10))).route) val paymentRoutes = Seq( - PaymentBlindedContactInfo(OfferTypes.BlindedPath(blindedRoutes(0)), PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 1_000_000_000 msat, Features.empty)), - PaymentBlindedContactInfo(OfferTypes.CompactBlindedPath(ShortChannelIdDir(isNode1 = true, RealShortChannelId(11111)), blindedRoutes(1).blindingKey, blindedRoutes(1).blindedNodes), PaymentInfo(1 msat, 11, CltvExpiryDelta(111), 0 msat, 1_000_000_000 msat, Features.empty)), - PaymentBlindedContactInfo(OfferTypes.BlindedPath(blindedRoutes(2)), PaymentInfo(2 msat, 22, CltvExpiryDelta(222), 0 msat, 1_000_000_000 msat, Features.empty)), - PaymentBlindedContactInfo(OfferTypes.CompactBlindedPath(ShortChannelIdDir(isNode1 = false, RealShortChannelId(33333)), blindedRoutes(3).blindingKey, blindedRoutes(3).blindedNodes), PaymentInfo(3 msat, 33, CltvExpiryDelta(333), 0 msat, 1_000_000_000 msat, Features.empty)), - PaymentBlindedContactInfo(OfferTypes.CompactBlindedPath(ShortChannelIdDir(isNode1 = false, RealShortChannelId(44444)), blindedRoutes(4).blindingKey, blindedRoutes(4).blindedNodes), PaymentInfo(4 msat, 44, CltvExpiryDelta(444), 0 msat, 1_000_000_000 msat, Features.empty)), - PaymentBlindedContactInfo(OfferTypes.BlindedPath(blindedRoutes(5)), PaymentInfo(5 msat, 55, CltvExpiryDelta(555), 0 msat, 1_000_000_000 msat, Features.empty)), + PaymentBlindedRoute(blindedRoutes(0), PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 1_000_000_000 msat, Features.empty)), + PaymentBlindedRoute(blindedRoutes(1).copy(introductionNodeId = ShortChannelIdDir(isNode1 = true, RealShortChannelId(11111))), PaymentInfo(1 msat, 11, CltvExpiryDelta(111), 0 msat, 1_000_000_000 msat, Features.empty)), + PaymentBlindedRoute(blindedRoutes(2), PaymentInfo(2 msat, 22, CltvExpiryDelta(222), 0 msat, 1_000_000_000 msat, Features.empty)), + PaymentBlindedRoute(blindedRoutes(3).copy(introductionNodeId = ShortChannelIdDir(isNode1 = false, RealShortChannelId(33333))), PaymentInfo(3 msat, 33, CltvExpiryDelta(333), 0 msat, 1_000_000_000 msat, Features.empty)), + PaymentBlindedRoute(blindedRoutes(4).copy(introductionNodeId = ShortChannelIdDir(isNode1 = false, RealShortChannelId(44444))), PaymentInfo(4 msat, 44, CltvExpiryDelta(444), 0 msat, 1_000_000_000 msat, Features.empty)), + PaymentBlindedRoute(blindedRoutes(5), PaymentInfo(5 msat, 55, CltvExpiryDelta(555), 0 msat, 1_000_000_000 msat, Features.empty)), ) val invoice = Bolt12Invoice(invoiceRequest, preimage, merchantKey, 1 minute, Features.empty, paymentRoutes) replyTo ! Postman.Response(InvoicePayload(TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)), TlvStream.empty)) @@ -162,7 +162,7 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val getNode1 = router.expectMsgType[Router.GetNodeId] assert(getNode1.isNode1) assert(getNode1.shortChannelId == RealShortChannelId(11111)) - getNode1.replyTo ! Some(blindedRoutes(1).introductionNodeId) + getNode1.replyTo ! Some(blindedRoutes(1).introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey) val getNode3 = router.expectMsgType[Router.GetNodeId] assert(!getNode3.isNode1) @@ -172,12 +172,12 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val getNode4 = router.expectMsgType[Router.GetNodeId] assert(!getNode4.isNode1) assert(getNode4.shortChannelId == RealShortChannelId(44444)) - getNode4.replyTo ! Some(blindedRoutes(4).introductionNodeId) + getNode4.replyTo ! Some(blindedRoutes(4).introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey) val send = paymentInitiator.expectMsgType[SendPaymentToNode] assert(send.invoice == invoice) - assert(send.resolvedPaths.map(_.route) == Seq(blindedRoutes(0), blindedRoutes(1), blindedRoutes(2), blindedRoutes(4), blindedRoutes(5))) - assert(send.resolvedPaths.map(_.paymentInfo.feeBase) == Seq(0 msat, 1 msat, 2 msat, 4 msat, 5 msat)) + assert(send.resolvedPaths.map(_.introductionNodeId) == Seq(blindedRoutes(0), blindedRoutes(1), blindedRoutes(2), blindedRoutes(4), blindedRoutes(5)).map(_.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey)) + assert(send.resolvedPaths.map(_.blindedPath.paymentInfo.feeBase) == Seq(0 msat, 1 msat, 2 msat, 4 msat, 5 msat)) TypedProbe().expectTerminated(offerPayment) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index 633cf0bea9..8dc15cde55 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -31,12 +31,13 @@ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.payment.send.BlindedRecipient -import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedContactInfo, PaymentBlindedRoute} +import fr.acinq.eclair.payment.send.CompactBlindedPathsResolver.ResolvedPath +import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedRoute} import fr.acinq.eclair.router.Announcements._ import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts -import fr.acinq.eclair.wire.protocol.OfferTypes.{BlindedPath, InvoiceRequest, Offer} +import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol._ import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -270,12 +271,12 @@ object BaseRouterSpec { val offer = Offer(None, "Bolt12 r0cks", recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, amount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) val blindedRoutes = paths.map(hops => { - val blindedRoute = OfferTypes.BlindedPath(BlindedRouteCreation.createBlindedRouteFromHops(hops, pathId, 1 msat, routeExpiry).route) + val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(hops, pathId, 1 msat, routeExpiry).route val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(amount, hops, Channel.MIN_CLTV_EXPIRY_DELTA) - PaymentBlindedContactInfo(blindedRoute, paymentInfo) + PaymentBlindedRoute(blindedRoute, paymentInfo) }) val invoice = Bolt12Invoice(invoiceRequest, preimage, recipientKey, 300 seconds, features, blindedRoutes) - val resolvedPaths = invoice.blindedPaths.map(path => PaymentBlindedRoute(path.route.asInstanceOf[BlindedPath].route, path.paymentInfo)) + val resolvedPaths = invoice.blindedPaths.map(path => ResolvedPath(path, path.route.introductionNodeId.asInstanceOf[EncodedNodeId.Plain].publicKey)) val recipient = BlindedRecipient(invoice, resolvedPaths, amount, expiry, Set.empty) (invoice, recipient) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala index 5510bb594f..896fff9bf6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.router import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdateShort import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} import fr.acinq.eclair.wire.protocol.{BlindedRouteData, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuite import org.scalatest.{ParallelTestExecution, Tag} @@ -31,7 +31,7 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { val a = randomKey() val pathId = randomBytes32() val route = createBlindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500)) - assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.introductionNodeId == EncodedNodeId(a.publicKey)) assert(route.route.encryptedPayloads.length == 1) assert(route.route.blindingKey == route.lastBlinding) val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head) @@ -48,7 +48,7 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { ChannelHop(scid2, b.publicKey, c.publicKey, HopRelayParams.FromAnnouncement(makeUpdateShort(scid2, b.publicKey, c.publicKey, 20 msat, 150, cltvDelta = CltvExpiryDelta(600)))), ) val route = createBlindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) - assert(route.route.introductionNodeId == a.publicKey) + assert(route.route.introductionNodeId == EncodedNodeId(a.publicKey)) assert(route.route.encryptedPayloads.length == 3) val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0)) assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala index 009dec657d..26345bfeef 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala @@ -4,7 +4,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.crypto.Sphinx.RouteBlinding -import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedContactInfo} +import fr.acinq.eclair.payment.{Bolt12Invoice, PaymentBlindedRoute} import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, IntermediatePayload, InvalidResponsePayload, InvoiceErrorPayload, InvoicePayload, InvoiceRequestPayload} import fr.acinq.eclair.wire.protocol.MessageOnionCodecs._ import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo @@ -12,7 +12,7 @@ import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, OutgoingNodeId, PathId, PaymentConstraints, PaymentRelay} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshiLong, EncodedNodeId, UInt64, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, UInt64, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuiteLike import scodec.bits.{ByteVector, HexStringSyntax} @@ -95,7 +95,7 @@ class MessageOnionCodecsSpec extends AnyFunSuiteLike { val payerKey = randomKey() val request = OfferTypes.InvoiceRequest(offer, 100_000 msat, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(randomBytes32()), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes - val route = PaymentBlindedContactInfo(OfferTypes.BlindedPath(Sphinx.RouteBlinding.create(randomKey(), Seq(nodeKey.publicKey), Seq(selfPayload)).route), PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) + val route = PaymentBlindedRoute(Sphinx.RouteBlinding.create(randomKey(), Seq(nodeKey.publicKey), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) val invoice = Bolt12Invoice(request, randomBytes32(), nodeKey, 300 seconds, Features.empty, Seq(route)) val testCasesInvalid = Seq[TlvStream[OnionMessagePayloadTlv]]( @@ -106,7 +106,7 @@ class MessageOnionCodecsSpec extends AnyFunSuiteLike { // Invoice and unknown TLV. TlvStream(Set[OnionMessagePayloadTlv](EncryptedData(hex""), Invoice(invoice.records)), Set(GenericTlv(UInt64(1), hex""))), // Invoice and ReplyPath. - TlvStream(EncryptedData(hex""), Invoice(invoice.records), ReplyPath(route.route.asInstanceOf[OfferTypes.BlindedPath].route)), + TlvStream(EncryptedData(hex""), Invoice(invoice.records), ReplyPath(route.route)), // Invoice and InvoiceError. TlvStream(EncryptedData(hex""), Invoice(invoice.records), InvoiceError(TlvStream(OfferTypes.Error("")))), // InvoiceRequest without ReplyPath. @@ -118,7 +118,7 @@ class MessageOnionCodecsSpec extends AnyFunSuiteLike { assert(finalPayload.isInstanceOf[InvalidResponsePayload]) } - val Right(invoiceRequestPayload) = FinalPayload.validate(TlvStream(EncryptedData(hex""), InvoiceRequest(request.records), ReplyPath(route.route.asInstanceOf[OfferTypes.BlindedPath].route)), TlvStream.empty) + val Right(invoiceRequestPayload) = FinalPayload.validate(TlvStream(EncryptedData(hex""), InvoiceRequest(request.records), ReplyPath(route.route)), TlvStream.empty) assert(invoiceRequestPayload.isInstanceOf[InvoiceRequestPayload]) val Right(invoicePayload) = FinalPayload.validate(TlvStream(EncryptedData(hex""), Invoice(invoice.records)), TlvStream.empty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala index 0f21ba07e2..27290305c2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OfferTypesSpec.scala @@ -19,7 +19,6 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.Bech32 import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32} -import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.BasicMultiPartPayment import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} @@ -264,23 +263,23 @@ class OfferTypesSpec extends AnyFunSuite { } test("compact blinded route") { - case class TestCase(encoded: ByteVector, decoded: BlindedContactInfo) + case class TestCase(encoded: ByteVector, decoded: BlindedRoute) val testCases = Seq( TestCase(hex"00 00000000000004d2 0379b470d00b78ded936f8972a0f3ecda2bb6e6df40dcd581dbaeb3742b30008ff 01 02fba71b72623187dd24670110eec870e28b848f255ba2edc0486d3a8e89ec44b7 0002 1dea", - CompactBlindedPath(ShortChannelIdDir(isNode1 = true, RealShortChannelId(1234)), PublicKey(hex"0379b470d00b78ded936f8972a0f3ecda2bb6e6df40dcd581dbaeb3742b30008ff"), Seq(BlindedNode(PublicKey(hex"02fba71b72623187dd24670110eec870e28b848f255ba2edc0486d3a8e89ec44b7"), hex"1dea")))), + BlindedRoute(EncodedNodeId.ShortChannelIdDir(isNode1 = true, RealShortChannelId(1234)), PublicKey(hex"0379b470d00b78ded936f8972a0f3ecda2bb6e6df40dcd581dbaeb3742b30008ff"), Seq(BlindedNode(PublicKey(hex"02fba71b72623187dd24670110eec870e28b848f255ba2edc0486d3a8e89ec44b7"), hex"1dea")))), TestCase(hex"01 000000000000ddd5 0353a081bb02d6e361be3df3e92b41b788ca65667f6ea0c01e2bfa03664460ef86 01 03bce3f0cdb4172caac82ec8a9251eb35df1201bdcb977c5a03f3624ec4156a65f 0003 c0ffee", - CompactBlindedPath(ShortChannelIdDir(isNode1 = false, RealShortChannelId(56789)), PublicKey(hex"0353a081bb02d6e361be3df3e92b41b788ca65667f6ea0c01e2bfa03664460ef86"), Seq(BlindedNode(PublicKey(hex"03bce3f0cdb4172caac82ec8a9251eb35df1201bdcb977c5a03f3624ec4156a65f"), hex"c0ffee")))), + BlindedRoute(EncodedNodeId.ShortChannelIdDir(isNode1 = false, RealShortChannelId(56789)), PublicKey(hex"0353a081bb02d6e361be3df3e92b41b788ca65667f6ea0c01e2bfa03664460ef86"), Seq(BlindedNode(PublicKey(hex"03bce3f0cdb4172caac82ec8a9251eb35df1201bdcb977c5a03f3624ec4156a65f"), hex"c0ffee")))), TestCase(hex"022d3b15cea00ee4a8e710b082bef18f0f3409cc4e7aff41c26eb0a4d3ab20dd73 0379a3b6e4bceb7519d09db776994b1f82cf6a9fa4d3ec2e52314c5938f2f9f966 01 02b446aaa523df82a992ab468e5298eabb6168e2c466455c210d8c97dbb8981328 0002 cafe", - BlindedPath(BlindedRoute(PublicKey(hex"022d3b15cea00ee4a8e710b082bef18f0f3409cc4e7aff41c26eb0a4d3ab20dd73"), PublicKey(hex"0379a3b6e4bceb7519d09db776994b1f82cf6a9fa4d3ec2e52314c5938f2f9f966"), Seq(BlindedNode(PublicKey(hex"02b446aaa523df82a992ab468e5298eabb6168e2c466455c210d8c97dbb8981328"), hex"cafe"))))), + BlindedRoute(EncodedNodeId.Plain(PublicKey(hex"022d3b15cea00ee4a8e710b082bef18f0f3409cc4e7aff41c26eb0a4d3ab20dd73")), PublicKey(hex"0379a3b6e4bceb7519d09db776994b1f82cf6a9fa4d3ec2e52314c5938f2f9f966"), Seq(BlindedNode(PublicKey(hex"02b446aaa523df82a992ab468e5298eabb6168e2c466455c210d8c97dbb8981328"), hex"cafe")))), TestCase(hex"03ba3c458e3299eb19d2e07ae86453f4290bcdf8689707f0862f35194397c45922 028aa5d1a10463d598a0a0ab7296af21619049f94fe03ef664a87561009e58c3dd 01 02988d7381d0434cfebbe521031505fb9987ae6cefd0bab0e5927852eb96bb6cc2 0003 ec1a13", - BlindedPath(BlindedRoute(PublicKey(hex"03ba3c458e3299eb19d2e07ae86453f4290bcdf8689707f0862f35194397c45922"), PublicKey(hex"028aa5d1a10463d598a0a0ab7296af21619049f94fe03ef664a87561009e58c3dd"), Seq(BlindedNode(PublicKey(hex"02988d7381d0434cfebbe521031505fb9987ae6cefd0bab0e5927852eb96bb6cc2"), hex"ec1a13"))))), + BlindedRoute(EncodedNodeId.Plain(PublicKey(hex"03ba3c458e3299eb19d2e07ae86453f4290bcdf8689707f0862f35194397c45922")), PublicKey(hex"028aa5d1a10463d598a0a0ab7296af21619049f94fe03ef664a87561009e58c3dd"), Seq(BlindedNode(PublicKey(hex"02988d7381d0434cfebbe521031505fb9987ae6cefd0bab0e5927852eb96bb6cc2"), hex"ec1a13")))), ) testCases.foreach { case TestCase(encoded, decoded) => - assert(OfferCodecs.pathCodec.encode(decoded).require.bytes == encoded) - assert(OfferCodecs.pathCodec.decode(encoded.bits).require.value == decoded) + assert(OfferCodecs.blindedRouteCodec.encode(decoded).require.bytes == encoded) + assert(OfferCodecs.blindedRouteCodec.decode(encoded.bits).require.value == decoded) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index c3a32d30dc..a7e4fe2788 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -20,13 +20,14 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.UInt64.Conversions._ import fr.acinq.eclair.crypto.Sphinx.RouteBlinding +import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop -import fr.acinq.eclair.payment.PaymentBlindedContactInfo +import fr.acinq.eclair.payment.PaymentBlindedRoute import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.PaymentOnion._ import fr.acinq.eclair.wire.protocol.PaymentOnionCodecs._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshiLong, EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, FeatureSupport, Features, MilliSatoshiLong, RealShortChannelId, ShortChannelId, UInt64, randomKey} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -166,12 +167,12 @@ class PaymentOnionSpec extends AnyFunSuite { test("encode/decode node relay to blinded paths per-hop payload") { val features = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional).toByteVector - val blindedRoute = OfferTypes.CompactBlindedPath( + val blindedRoute = BlindedRoute( EncodedNodeId.ShortChannelIdDir(isNode1 = false, RealShortChannelId(468)), PublicKey(hex"0232882c4982576e00f0d6bd4998f5b3e92d47ecc8fbad5b6a5e7521819d891d9e"), Seq(RouteBlinding.BlindedNode(PublicKey(hex"03823aa560d631e9d7b686be4a9227e577009afb5173023b458a6a6aff056ac980"), hex"")) ) - val path = PaymentBlindedContactInfo(blindedRoute, OfferTypes.PaymentInfo(1000 msat, 678, CltvExpiryDelta(82), 300 msat, 4000000 msat, Features.empty)) + val path = PaymentBlindedRoute(blindedRoute, OfferTypes.PaymentInfo(1000 msat, 678, CltvExpiryDelta(82), 300 msat, 4000000 msat, Features.empty)) val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(341 msat), OutgoingCltv(CltvExpiry(826483)), OutgoingBlindedPaths(Seq(path)), InvoiceFeatures(features)) val bin = hex"82 02020155 04030c9c73 fe0001023103020000 fe000102366a0100000000000001d40232882c4982576e00f0d6bd4998f5b3e92d47ecc8fbad5b6a5e7521819d891d9e0103823aa560d631e9d7b686be4a9227e577009afb5173023b458a6a6aff056ac9800000000003e8000002a60052000000000000012c00000000003d09000000" diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/serde/FormParamExtractors.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/serde/FormParamExtractors.scala index 1c934fa1aa..e318a82c66 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/serde/FormParamExtractors.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/serde/FormParamExtractors.scala @@ -26,7 +26,7 @@ import fr.acinq.eclair.blockchain.fee.{ConfirmationPriority, FeeratePerByte} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.io.NodeURI import fr.acinq.eclair.payment.Bolt11Invoice -import fr.acinq.eclair.wire.protocol.MessageOnionCodecs.blindedRouteCodec +import fr.acinq.eclair.wire.protocol.OfferCodecs.blindedRouteCodec import fr.acinq.eclair.wire.protocol.OfferTypes.Offer import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, TimestampSecond} import scodec.bits.ByteVector diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index e69a06c798..d73bf4b175 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -1240,7 +1240,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM GenericTlv(UInt64(5), hex"1111") )), TlvStream(RouteBlindingEncryptedDataTlv.PathId(hex"2222"))) val msgrcv = OnionMessages.ReceiveMessage(payload) - val expectedSerializedMsgrcv = """{"type":"onion-message-received","pathId":"2222","tlvs":{"EncryptedData":{"data":""},"ReplyPath":{"blindedRoute":{"introductionNodeId":"039dc0e0b1d25905e44fdf6f8e89755a5e219685840d0bc1d28d3308f9628a3585","blindingKey":"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619","blindedNodes":[{"blindedPublicKey":"020303f91e620504cde242df38d04599d8b4d4c555149cc742a5f12de452cbdd40","encryptedPayload":"126a26221759247584d704b382a5789f1d8c5a"}]}},"Unknown5":"1111"}}""" + val expectedSerializedMsgrcv = """{"type":"onion-message-received","pathId":"2222","tlvs":{"EncryptedData":{"data":""},"ReplyPath":{"blindedRoute":{"introductionNodeId":{"publicKey":"039dc0e0b1d25905e44fdf6f8e89755a5e219685840d0bc1d28d3308f9628a3585"},"blindingKey":"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619","blindedNodes":[{"blindedPublicKey":"020303f91e620504cde242df38d04599d8b4d4c555149cc742a5f12de452cbdd40","encryptedPayload":"126a26221759247584d704b382a5789f1d8c5a"}]}},"Unknown5":"1111"}}""" assert(serialization.write(msgrcv) == expectedSerializedMsgrcv) system.eventStream.publish(msgrcv) wsClient.expectMessage(expectedSerializedMsgrcv)