Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relay onion messages to compact node id #2821

Merged
merged 5 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/EncodedNodeId.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package fr.acinq.eclair

import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey

sealed trait EncodedNodeId

object EncodedNodeId {
/** Nodes are usually identified by their public key. */
case class Plain(publicKey: PublicKey) extends EncodedNodeId {
t-bast marked this conversation as resolved.
Show resolved Hide resolved
override def toString: String = publicKey.toString
}

/** For compactness, nodes may be identified by the shortChannelId of one of their public channels. */
case class ShortChannelIdDir(isNode1: Boolean, scid: RealShortChannelId) extends EncodedNodeId {
override def toString: String = if (isNode1) s"<-$scid" else s"$scid->"
}

def apply(publicKey: PublicKey): EncodedNodeId = Plain(publicKey)
}
2 changes: 1 addition & 1 deletion eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ class Setup(val datadir: File,
txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcher, bitcoinClient)
channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, bitcoinClient, txPublisherFactory)
pendingChannelsRateLimiter = system.spawn(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, channels)).onFailure(typed.SupervisorStrategy.resume), name = "pending-channels-rate-limiter")
peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register)
peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register, router.toTyped)

switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume))
_ = switchboard ! Switchboard.Init(channels)
Expand Down
123 changes: 80 additions & 43 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,27 @@
package fr.acinq.eclair.io

import akka.actor.typed.Behavior
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.{ActorRef, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.channel.Register
import fr.acinq.eclair.io.Peer.{PeerInfo, PeerInfoResponse}
import fr.acinq.eclair.io.Switchboard.GetPeerInfo
import fr.acinq.eclair.message.OnionMessages
import fr.acinq.eclair.message.OnionMessages.DropReason
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.OnionMessage
import fr.acinq.eclair.{EncodedNodeId, NodeParams, ShortChannelId}

object MessageRelay {
// @formatter:off
sealed trait Command
case class RelayMessage(messageId: ByteVector32,
switchboard: ActorRef,
register: ActorRef,
prevNodeId: PublicKey,
nextNode: Either[ShortChannelId, PublicKey],
nextNode: Either[ShortChannelId, EncodedNodeId],
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]) extends Command
Expand All @@ -60,66 +62,101 @@ object MessageRelay {
case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown outgoing channel: $outgoingChannelId"
}
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure {
override def toString: String = s"Message dropped: $reason"
}

sealed trait RelayPolicy
case object RelayChannelsOnly extends RelayPolicy
case object RelayAll extends RelayPolicy
// @formatter:on

def apply(): Behavior[Command] = {
Behaviors.receivePartial {
case (context, RelayMessage(messageId, switchboard, register, prevNodeId, Left(outgoingChannelId), msg, policy, replyTo_opt)) =>
def apply(nodeParams: NodeParams,
switchboard: ActorRef,
register: ActorRef,
router: typed.ActorRef[Router.GetNodeId]): Behavior[Command] = {
Behaviors.setup { context =>
Behaviors.receiveMessagePartial {
case RelayMessage(messageId, prevNodeId, nextNode, msg, policy, replyTo_opt) =>
val relay = new MessageRelay(nodeParams, messageId, prevNodeId, policy, switchboard, register, router, replyTo_opt, context)
relay.queryNextNodeId(msg, nextNode)
}
}
}
}

private class MessageRelay(nodeParams: NodeParams,
messageId: ByteVector32,
prevNodeId: PublicKey,
policy: MessageRelay.RelayPolicy,
switchboard: ActorRef,
register: ActorRef,
router: typed.ActorRef[Router.GetNodeId],
replyTo_opt: Option[typed.ActorRef[MessageRelay.Status]],
context: ActorContext[MessageRelay.Command]) {

import MessageRelay._

def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = {
nextNode match {
case Left(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
waitForNextNodeId(messageId, switchboard, prevNodeId, outgoingChannelId, msg, policy, replyTo_opt)
case (context, RelayMessage(messageId, switchboard, _, prevNodeId, Right(nextNodeId), msg, policy, replyTo_opt)) =>
withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)
waitForNextNodeId(msg, outgoingChannelId)
case Right(EncodedNodeId.ShortChannelIdDir(isNode1, scid)) =>
router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1)
waitForNextNodeId(msg, scid)
case Right(EncodedNodeId.Plain(nextNodeId)) =>
withNextNodeId(msg, nextNodeId)
}
}

def waitForNextNodeId(messageId: ByteVector32,
switchboard: ActorRef,
prevNodeId: PublicKey,
outgoingChannelId: ShortChannelId,
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] =
Behaviors.receivePartial {
case (_, WrappedOptionalNodeId(None)) =>
private def waitForNextNodeId(msg: OnionMessage, outgoingChannelId: ShortChannelId): Behavior[Command] =
Behaviors.receiveMessagePartial {
case WrappedOptionalNodeId(None) =>
replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId))
Behaviors.stopped
case (context, WrappedOptionalNodeId(Some(nextNodeId))) =>
withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)
case WrappedOptionalNodeId(Some(nextNodeId)) =>
withNextNodeId(msg, nextNodeId)
}

def withNextNodeId(context: ActorContext[Command],
messageId: ByteVector32,
switchboard: ActorRef,
prevNodeId: PublicKey,
nextNodeId: PublicKey,
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] =
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeer(messageId, switchboard, nextNodeId, msg, replyTo_opt)
case RelayAll =>
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(messageId, msg, replyTo_opt)
}
private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
if (nextNodeId == nodeParams.nodeId) {
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(reason) =>
replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason))
Behaviors.stopped
case OnionMessages.SendMessage(nextNode, nextMessage) =>
// We need to repeat the process until we identify the (real) next node, or find out that we're the recipient.
queryNextNodeId(nextMessage, nextNode)
case received: OnionMessages.ReceiveMessage =>
context.system.eventStream ! EventStream.Publish(received)
replyTo_opt.foreach(_ ! Sent(messageId))
Behaviors.stopped
}
} else {
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeerForPolicyCheck(msg, nextNodeId)
case RelayAll =>
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(msg)
}
}
}

def waitForPreviousPeer(messageId: ByteVector32, switchboard: ActorRef, nextNodeId: PublicKey, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = {
Behaviors.receivePartial {
case (context, WrappedPeerInfo(PeerInfo(_, _, _, _, channels))) if channels.nonEmpty =>
private def waitForPreviousPeerForPolicyCheck(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedPeerInfo(PeerInfo(_, _, _, _, channels)) if channels.nonEmpty =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), nextNodeId)
waitForNextPeer(messageId, msg, replyTo_opt)
waitForNextPeerForPolicyCheck(msg)
case _ =>
replyTo_opt.foreach(_ ! AgainstPolicy(messageId, RelayChannelsOnly))
Behaviors.stopped
}
}

def waitForNextPeer(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = {
private def waitForNextPeerForPolicyCheck(msg: OnionMessage): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedPeerInfo(PeerInfo(peer, _, _, _, channels)) if channels.nonEmpty =>
peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt)
Expand All @@ -130,7 +167,7 @@ object MessageRelay {
}
}

def waitForConnection(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = {
private def waitForConnection(msg: OnionMessage): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedConnectionResult(r: PeerConnection.ConnectionResult.HasConnection) =>
r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt)
Expand Down
17 changes: 13 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import fr.acinq.eclair.io.OpenChannelInterceptor.{OpenChannelInitiator, OpenChan
import fr.acinq.eclair.io.PeerConnection.KillReason
import fr.acinq.eclair.message.OnionMessages
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, OnionMessage, RoutingMessage, UnknownMessage, Warning}

Expand All @@ -51,7 +52,14 @@ import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId
*
* Created by PM on 26/08/2016.
*/
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
class Peer(val nodeParams: NodeParams,
remoteNodeId: PublicKey,
wallet: OnchainPubkeyCache,
channelFactory: Peer.ChannelFactory,
switchboard: ActorRef,
register: ActorRef,
router: typed.ActorRef[Router.GetNodeId],
pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {

import Peer._

Expand Down Expand Up @@ -279,8 +287,8 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP
log.debug("dropping message from {}: {}", remoteNodeId.value.toHex, reason.toString)
case OnionMessages.SendMessage(nextNode, message) if nodeParams.features.hasFeature(Features.OnionMessages) =>
val messageId = randomBytes32()
val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, switchboard, register, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None)
val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None)
case OnionMessages.SendMessage(_, _) =>
log.debug("dropping message from {}: relaying onion messages is disabled", remoteNodeId.value.toHex)
case received: OnionMessages.ReceiveMessage =>
Expand Down Expand Up @@ -458,7 +466,8 @@ object Peer {
context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, txPublisherFactory))
}

def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, pendingChannelsRateLimiter))
def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, router: typed.ActorRef[Router.GetNodeId], pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props =
Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, router, pendingChannelsRateLimiter))

// @formatter:off

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import fr.acinq.eclair.channel._
import fr.acinq.eclair.io.IncomingConnectionsTracker.TrackIncomingConnection
import fr.acinq.eclair.io.Peer.{PeerInfoResponse, PeerNotFound}
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.router.Router.RouterConf
import fr.acinq.eclair.{NodeParams, SubscriptionsComplete}

Expand Down Expand Up @@ -159,9 +160,9 @@ object Switchboard {
def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef
}

case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef) extends PeerFactory {
case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef, router: typed.ActorRef[Router.GetNodeId]) extends PeerFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef =
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId))
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, router, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId))
}

def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package fr.acinq.eclair.message

import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.{EncodedNodeId, ShortChannelId}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.io.MessageRelay.RelayPolicy
import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, IntermediatePayload}
Expand Down Expand Up @@ -105,9 +105,9 @@ object OnionMessages {
case Left(_) => None
case Right(decoded) =>
decoded.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId] match {
case None => None
case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(nextNodeId)) =>
case Some(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.Plain(nextNodeId))) =>
Some(Sphinx.RouteBlinding.BlindedRoute(nextNodeId, decoded.nextBlinding, route.blindedNodes.tail))
case _ => None // TODO: allow compact node id and OutgoingChannelId
}
}
case BlindedPath(route) if intermediateNodes.isEmpty => Some(route)
Expand Down Expand Up @@ -165,7 +165,7 @@ object OnionMessages {
// @formatter:off
sealed trait Action
case class DropMessage(reason: DropReason) extends Action
case class SendMessage(nextNode: Either[ShortChannelId, PublicKey], message: OnionMessage) extends Action
case class SendMessage(nextNode: Either[ShortChannelId, EncodedNodeId], message: OnionMessage) extends Action
case class ReceiveMessage(finalPayload: FinalPayload) extends Action

sealed trait DropReason
Expand Down Expand Up @@ -211,8 +211,8 @@ object OnionMessages {
case Left(f) => DropMessage(f)
case Right(DecodedEncryptedData(blindedPayload, nextBlinding)) => nextPacket_opt match {
case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) match {
case SendMessage(Right(nextNodeId), nextMsg) if nextNodeId == privateKey.publicKey => process(privateKey, nextMsg)
case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg)
case SendMessage(Right(EncodedNodeId.Plain(publicKey)), nextMsg) if publicKey == privateKey.publicKey => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) // TODO: remove and rely on MessageRelay
case action => action
}
case None => validateFinalPayload(payload, blindedPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteNotFound, Messag
import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoiceRequestPayload}
import fr.acinq.eclair.wire.protocol.OfferTypes.{CompactBlindedPath, ContactInfo}
import fr.acinq.eclair.wire.protocol.{OfferTypes, OnionMessagePayloadTlv, TlvStream}
import fr.acinq.eclair.{NodeParams, randomBytes32, randomKey}
import fr.acinq.eclair.{EncodedNodeId, NodeParams, randomBytes32, randomKey}

import scala.collection.mutable

Expand Down Expand Up @@ -214,8 +214,8 @@ private class SendingMessage(nodeParams: NodeParams,
replyTo ! Postman.MessageFailed(failure.toString)
Behaviors.stopped
case Right((nextNodeId, message)) =>
val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, switchboard, register, nodeParams.nodeId, Right(nextNodeId), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus)))
val relay = context.spawn(Behaviors.supervise(MessageRelay(nodeParams, switchboard, register, router)).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, nodeParams.nodeId, Right(EncodedNodeId(nextNodeId)), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus)))
waitForSent()
}
}
Expand Down
Loading
Loading