From bb004f9eba74d8981fcc51fbb241d4c31cf4190e Mon Sep 17 00:00:00 2001 From: sstone Date: Thu, 6 Jun 2024 11:03:45 +0200 Subject: [PATCH] [WIP] Implement simple taproot channels --- .../kotlin/fr/acinq/lightning/Features.kt | 12 +- .../kotlin/fr/acinq/lightning/NodeParams.kt | 3 +- .../acinq/lightning/channel/ChannelAction.kt | 1 + .../lightning/channel/ChannelFeatures.kt | 5 + .../fr/acinq/lightning/channel/Commitments.kt | 39 +- .../fr/acinq/lightning/channel/Helpers.kt | 35 +- .../acinq/lightning/channel/states/Normal.kt | 2 +- .../serialization/v4/Serialization.kt | 3 + .../acinq/lightning/transactions/Scripts.kt | 116 +++++ .../lightning/transactions/Transactions.kt | 442 +++++++++++++--- .../fr/acinq/lightning/wire/ChannelTlv.kt | 87 ++++ .../acinq/lightning/wire/LightningMessages.kt | 13 +- .../channel/states/NormalTestsCommon.kt | 2 +- .../transactions/AnchorOutputsTestsCommon.kt | 5 +- .../transactions/TransactionsTestsCommon.kt | 479 +++++++++++++++++- 15 files changed, 1123 insertions(+), 121 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt index 86c8235e4..bf54d9fa6 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt @@ -256,6 +256,12 @@ sealed class Feature { override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } + @Serializable + object SimpleTaprootStaging : Feature() { + override val rfcName get() = "option_simple_taproot_staging" + override val mandatory get() = 180 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) + } } @Serializable @@ -337,7 +343,8 @@ data class Features(val activated: Map, val unknown: Se Feature.ChannelBackupClient, Feature.ChannelBackupProvider, Feature.ExperimentalSplice, - Feature.Quiescence + Feature.Quiescence, + Feature.SimpleTaprootStaging ) operator fun invoke(bytes: ByteVector): Features = invoke(bytes.toByteArray()) @@ -369,7 +376,8 @@ data class Features(val activated: Map, val unknown: Se Feature.BasicMultiPartPayment to listOf(Feature.PaymentSecret), Feature.AnchorOutputs to listOf(Feature.StaticRemoteKey), Feature.TrampolinePayment to listOf(Feature.PaymentSecret), - Feature.ExperimentalTrampolinePayment to listOf(Feature.PaymentSecret) + Feature.ExperimentalTrampolinePayment to listOf(Feature.PaymentSecret), + Feature.SimpleTaprootStaging to listOf(Feature.AnchorOutputs, Feature.StaticRemoteKey) ) class FeatureException(message: String) : IllegalArgumentException(message) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt b/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt index abec81b2a..50ed341a1 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt @@ -201,7 +201,8 @@ data class NodeParams( Feature.PayToOpenClient to FeatureSupport.Optional, Feature.ChannelBackupClient to FeatureSupport.Optional, Feature.ExperimentalSplice to FeatureSupport.Optional, - Feature.Quiescence to FeatureSupport.Mandatory + Feature.Quiescence to FeatureSupport.Mandatory, + Feature.SimpleTaprootStaging to FeatureSupport.Optional ), dustLimit = 546.sat, maxRemoteDustLimit = 600.sat, diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelAction.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelAction.kt index 6b2d5aa08..b7b3434ff 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelAction.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelAction.kt @@ -50,6 +50,7 @@ sealed class ChannelAction { is Transactions.TransactionWithInputInfo.CommitTx -> Type.CommitTx is Transactions.TransactionWithInputInfo.HtlcTx.HtlcSuccessTx -> Type.HtlcSuccessTx is Transactions.TransactionWithInputInfo.HtlcTx.HtlcTimeoutTx -> Type.HtlcTimeoutTx + is Transactions.TransactionWithInputInfo.HtlcDelayedTx -> Type.ClaimHtlcTimeoutTx is Transactions.TransactionWithInputInfo.ClaimHtlcTx.ClaimHtlcSuccessTx -> Type.ClaimHtlcSuccessTx is Transactions.TransactionWithInputInfo.ClaimHtlcTx.ClaimHtlcTimeoutTx -> Type.ClaimHtlcTimeoutTx is Transactions.TransactionWithInputInfo.ClaimAnchorOutputTx.ClaimLocalAnchorOutputTx -> Type.ClaimLocalAnchorOutputTx diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelFeatures.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelFeatures.kt index ed359c5c9..ae7f3405b 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelFeatures.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/ChannelFeatures.kt @@ -59,6 +59,10 @@ sealed class ChannelType { override val features: Set get() = setOf(Feature.StaticRemoteKey, Feature.AnchorOutputs, Feature.ZeroReserveChannels) } + object SimpleTaprootStaging : SupportedChannelType() { + override val name: String get() = "simple_taproot_staging" + override val features: Set get() = setOf(Feature.SimpleTaprootStaging, Feature.StaticRemoteKey, Feature.AnchorOutputs, Feature.ZeroReserveChannels) + } } data class UnsupportedChannelType(val featureBits: Features) : ChannelType() { @@ -71,6 +75,7 @@ sealed class ChannelType { // NB: Bolt 2: features must exactly match in order to identify a channel type. fun fromFeatures(features: Features): ChannelType = when (features) { // @formatter:off + Features(Feature.StaticRemoteKey to FeatureSupport.Mandatory, Feature.AnchorOutputs to FeatureSupport.Mandatory, Feature.ZeroReserveChannels to FeatureSupport.Mandatory, Feature.SimpleTaprootStaging to FeatureSupport.Optional) -> SupportedChannelType.SimpleTaprootStaging Features(Feature.StaticRemoteKey to FeatureSupport.Mandatory, Feature.AnchorOutputs to FeatureSupport.Mandatory, Feature.ZeroReserveChannels to FeatureSupport.Mandatory) -> SupportedChannelType.AnchorOutputsZeroReserve Features(Feature.StaticRemoteKey to FeatureSupport.Mandatory, Feature.AnchorOutputs to FeatureSupport.Mandatory) -> SupportedChannelType.AnchorOutputs else -> UnsupportedChannelType(features) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt index c1b728c9b..817fc7371 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt @@ -2,10 +2,12 @@ package fr.acinq.lightning.channel import fr.acinq.bitcoin.* import fr.acinq.bitcoin.Crypto.sha256 +import fr.acinq.bitcoin.crypto.musig2.IndividualNonce import fr.acinq.bitcoin.utils.Either import fr.acinq.bitcoin.utils.Try import fr.acinq.lightning.CltvExpiryDelta import fr.acinq.lightning.Feature +import fr.acinq.lightning.Features import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.blockchain.fee.FeeratePerByte import fr.acinq.lightning.blockchain.fee.FeeratePerKw @@ -46,6 +48,8 @@ data class ChannelParams( require(channelConfig.hasOption(ChannelConfigOption.FundingPubKeyBasedChannelKeyPath)) { "FundingPubKeyBasedChannelKeyPath option must be enabled" } } + val isTaprootChannel = Features.canUseFeature(localParams.features, remoteParams.features, Feature.SimpleTaprootStaging) + fun updateFeatures(localInit: Init, remoteInit: Init) = this.copy( localParams = localParams.copy(features = localInit.features), remoteParams = remoteParams.copy(features = remoteInit.features) @@ -94,6 +98,7 @@ data class CommitmentChanges(val localChanges: LocalChanges, val remoteChanges: data class HtlcTxAndSigs(val txinfo: HtlcTx, val localSig: ByteVector64, val remoteSig: ByteVector64) data class PublishableTxs(val commitTx: CommitTx, val htlcTxsAndSigs: List) +data class PartialSignatureWithNonce(val partialSig: ByteVector32, val nonce: IndividualNonce) /** The local commitment maps to a commitment transaction that we can sign and broadcast if necessary. */ data class LocalCommit(val index: Long, val spec: CommitmentSpec, val publishableTxs: PublishableTxs) { @@ -254,9 +259,9 @@ data class Commitment( val balanceNoFees = (reduced.toRemote - localChannelReserve(params).toMilliSatoshi()).coerceAtLeast(0.msat) return if (params.localParams.isInitiator) { // The initiator always pays the on-chain fees, so we must subtract that from the amount we can send. - val commitFees = commitTxFeeMsat(params.remoteParams.dustLimit, reduced) + val commitFees = commitTxFeeMsat(params.remoteParams.dustLimit, reduced, params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) // the initiator needs to keep a "initiator fee buffer" (see explanation above) - val initiatorFeeBuffer = commitTxFeeMsat(params.remoteParams.dustLimit, reduced.copy(feerate = reduced.feerate * 2)) + htlcOutputFee(reduced.feerate * 2) + val initiatorFeeBuffer = commitTxFeeMsat(params.remoteParams.dustLimit, reduced.copy(feerate = reduced.feerate * 2), params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) + htlcOutputFee(reduced.feerate * 2) val amountToReserve = commitFees.coerceAtLeast(initiatorFeeBuffer) if (balanceNoFees - amountToReserve < offeredHtlcTrimThreshold(params.remoteParams.dustLimit, reduced).toMilliSatoshi()) { // htlc will be trimmed @@ -283,9 +288,9 @@ data class Commitment( balanceNoFees } else { // The initiator always pays the on-chain fees, so we must subtract that from the amount we can receive. - val commitFees = commitTxFeeMsat(params.localParams.dustLimit, reduced) + val commitFees = commitTxFeeMsat(params.localParams.dustLimit, reduced, params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) // we expected the initiator to keep a "initiator fee buffer" (see explanation above) - val initiatorFeeBuffer = commitTxFeeMsat(params.localParams.dustLimit, reduced.copy(feerate = reduced.feerate * 2)) + htlcOutputFee(reduced.feerate * 2) + val initiatorFeeBuffer = commitTxFeeMsat(params.localParams.dustLimit, reduced.copy(feerate = reduced.feerate * 2), params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) + htlcOutputFee(reduced.feerate * 2) val amountToReserve = commitFees.coerceAtLeast(initiatorFeeBuffer) if (balanceNoFees - amountToReserve < receivedHtlcTrimThreshold(params.localParams.dustLimit, reduced).toMilliSatoshi()) { // htlc will be trimmed @@ -351,10 +356,10 @@ data class Commitment( val outgoingHtlcs = reduced.htlcs.incomings() // note that the initiator pays the fee, so if sender != initiator, both sides will have to afford this payment - val fees = commitTxFee(params.remoteParams.dustLimit, reduced) + val fees = commitTxFee(params.remoteParams.dustLimit, reduced, params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) // the initiator needs to keep an extra buffer to be able to handle a x2 feerate increase and an additional htlc to avoid // getting the channel stuck (see https://github.com/lightningnetwork/lightning-rfc/issues/728). - val initiatorFeeBuffer = commitTxFeeMsat(params.remoteParams.dustLimit, reduced.copy(feerate = reduced.feerate * 2)) + htlcOutputFee(reduced.feerate * 2) + val initiatorFeeBuffer = commitTxFeeMsat(params.remoteParams.dustLimit, reduced.copy(feerate = reduced.feerate * 2), params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) + htlcOutputFee(reduced.feerate * 2) // NB: increasing the feerate can actually remove htlcs from the commit tx (if they fall below the trim threshold) // which may result in a lower commit tx fee; this is why we take the max of the two. val missingForSender = reduced.toRemote - localChannelReserve(params).toMilliSatoshi() - (if (params.localParams.isInitiator) fees.toMilliSatoshi().coerceAtLeast(initiatorFeeBuffer) else 0.msat) @@ -403,7 +408,7 @@ data class Commitment( val incomingHtlcs = reduced.htlcs.incomings() // note that the initiator pays the fee, so if sender != initiator, both sides will have to afford this payment - val fees = commitTxFee(params.localParams.dustLimit, reduced) + val fees = commitTxFee(params.localParams.dustLimit, reduced, params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) // NB: we don't enforce the initiatorFeeReserve (see sendAdd) because it would confuse a remote initiator that doesn't have this mitigation in place // We could enforce it once we're confident a large portion of the network implements it. val missingForSender = reduced.toRemote - remoteChannelReserve(params).toMilliSatoshi() - (if (params.localParams.isInitiator) 0.sat else fees).toMilliSatoshi() @@ -436,7 +441,7 @@ data class Commitment( val reduced = CommitmentSpec.reduce(remoteCommit.spec, changes.remoteChanges.acked, changes.localChanges.proposed) // a node cannot spend pending incoming htlcs, and need to keep funds above the reserve required by the counterparty, after paying the fee // we look from remote's point of view, so if local is initiator remote doesn't pay the fees - val fees = commitTxFee(params.remoteParams.dustLimit, reduced) + val fees = commitTxFee(params.remoteParams.dustLimit, reduced, params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) val missing = reduced.toRemote.truncateToSatoshi() - localChannelReserve(params) - fees return if (missing < 0.sat) { Either.Left(CannotAffordFees(params.channelId, -missing, localChannelReserve(params), fees)) @@ -453,7 +458,7 @@ data class Commitment( // It is easier to do it here because under certain (race) conditions spec allows a lower-than-normal fee to be paid, // and it would be tricky to check if the conditions are met at signing // (it also means that we need to check the fee of the initial commitment tx somewhere) - val fees = commitTxFee(params.localParams.dustLimit, reduced) + val fees = commitTxFee(params.localParams.dustLimit, reduced, params.channelFeatures.hasFeature(Feature.SimpleTaprootStaging)) val missing = reduced.toRemote.truncateToSatoshi() - remoteChannelReserve(params) - fees return if (missing < 0.sat) { Either.Left(CannotAffordFees(params.channelId, -missing, remoteChannelReserve(params), fees)) @@ -554,6 +559,7 @@ data class FullCommitment( params.channelFeatures.hasFeature(Feature.ZeroReserveChannels) -> 0.sat else -> (fundingAmount / 100).max(params.localParams.dustLimit) } + val isTaprootChannel = params.isTaprootChannel } data class WaitingForRevocation(val sentAfterLocalCommitIndex: Long) @@ -575,6 +581,7 @@ data class Commitments( val channelId: ByteVector32 = params.channelId val localNodeId: PublicKey = params.localParams.nodeId val remoteNodeId: PublicKey = params.remoteParams.nodeId + val isTaprootChannel = params.isTaprootChannel // Commitment numbers are the same for all active commitments. val localCommitIndex = active.first().localCommit.index @@ -982,6 +989,7 @@ data class Commitments( val ANCHOR_AMOUNT = 330.sat const val COMMIT_WEIGHT = 1124 + const val COMMIT_WEIGHT_TAPROOT = 968 const val HTLC_OUTPUT_WEIGHT = 172 const val HTLC_TIMEOUT_WEIGHT = 666 const val HTLC_SUCCESS_WEIGHT = 706 @@ -1028,6 +1036,7 @@ data class Commitments( val remoteHtlcPubkey = remoteParams.htlcBasepoint.deriveForCommitment(localPerCommitmentPoint) val localRevocationPubkey = remoteParams.revocationBasepoint.deriveForRevocation(localPerCommitmentPoint) val localPaymentBasepoint = channelKeys.paymentBasepoint + val isTaprootChannel = Features.canUseFeature(localParams.features, remoteParams.features, Feature.SimpleTaprootStaging) val outputs = makeCommitTxOutputs( channelKeys.fundingPubKey(fundingTxIndex), remoteFundingPubKey, @@ -1039,10 +1048,12 @@ data class Commitments( remotePaymentPubkey, localHtlcPubkey, remoteHtlcPubkey, - spec + spec, + isTaprootChannel ) + val commitTx = Transactions.makeCommitTx(commitmentInput, commitTxNumber, localPaymentBasepoint, remoteParams.paymentBasepoint, localParams.isInitiator, outputs) - val htlcTxs = Transactions.makeHtlcTxs(commitTx.tx, localParams.dustLimit, localRevocationPubkey, remoteParams.toSelfDelay, localDelayedPaymentPubkey, spec.feerate, outputs) + val htlcTxs = Transactions.makeHtlcTxs(commitTx.tx, localParams.dustLimit, localRevocationPubkey, remoteParams.toSelfDelay, localDelayedPaymentPubkey, spec.feerate, outputs, isTaprootChannel) return Pair(commitTx, htlcTxs) } @@ -1062,6 +1073,7 @@ data class Commitments( val remoteDelayedPaymentPubkey = remoteParams.delayedPaymentBasepoint.deriveForCommitment(remotePerCommitmentPoint) val remoteHtlcPubkey = remoteParams.htlcBasepoint.deriveForCommitment(remotePerCommitmentPoint) val remoteRevocationPubkey = channelKeys.revocationBasepoint.deriveForRevocation(remotePerCommitmentPoint) + val isTaprootChannel = Features.canUseFeature(localParams.features, remoteParams.features, Feature.SimpleTaprootStaging) val outputs = makeCommitTxOutputs( remoteFundingPubKey, channelKeys.fundingPubKey(fundingTxIndex), @@ -1073,11 +1085,12 @@ data class Commitments( localPaymentPubkey, remoteHtlcPubkey, localHtlcPubkey, - spec + spec, + isTaprootChannel ) // NB: we are creating the remote commit tx, so local/remote parameters are inverted. val commitTx = Transactions.makeCommitTx(commitmentInput, commitTxNumber, remoteParams.paymentBasepoint, localPaymentPubkey, !localParams.isInitiator, outputs) - val htlcTxs = Transactions.makeHtlcTxs(commitTx.tx, remoteParams.dustLimit, remoteRevocationPubkey, localParams.toSelfDelay, remoteDelayedPaymentPubkey, spec.feerate, outputs) + val htlcTxs = Transactions.makeHtlcTxs(commitTx.tx, remoteParams.dustLimit, remoteRevocationPubkey, localParams.toSelfDelay, remoteDelayedPaymentPubkey, spec.feerate, outputs, isTaprootChannel) return Pair(commitTx, htlcTxs) } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt index c37a4e884..ed9ce4b42 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt @@ -30,6 +30,7 @@ import fr.acinq.lightning.crypto.ShaChain import fr.acinq.lightning.logging.* import fr.acinq.lightning.transactions.* import fr.acinq.lightning.transactions.Scripts.multiSig2of2 +import fr.acinq.lightning.transactions.Scripts.musig2FundingScript import fr.acinq.lightning.transactions.Transactions.TransactionWithInputInfo.ClaimHtlcDelayedOutputPenaltyTx import fr.acinq.lightning.transactions.Transactions.TransactionWithInputInfo.ClaimHtlcTx.ClaimHtlcTimeoutTx import fr.acinq.lightning.transactions.Transactions.TransactionWithInputInfo.ClosingTx @@ -219,15 +220,18 @@ object Helpers { fundingTxOutputIndex: Int, fundingAmount: Satoshi, fundingPubkey1: PublicKey, - fundingPubkey2: PublicKey + fundingPubkey2: PublicKey, + isTaprootChannel: Boolean ): Transactions.InputInfo { - val fundingScript = multiSig2of2(fundingPubkey1, fundingPubkey2) - val fundingTxOut = TxOut(fundingAmount, pay2wsh(fundingScript)) - return Transactions.InputInfo( - OutPoint(fundingTxId, fundingTxOutputIndex.toLong()), - fundingTxOut, - ByteVector(write(fundingScript)) - ) + if (isTaprootChannel) { + val fundingScript = musig2FundingScript(fundingPubkey1, fundingPubkey2) + val fundingTxOut = TxOut(fundingAmount, fundingScript) + return Transactions.InputInfo(OutPoint(fundingTxId, fundingTxOutputIndex.toLong()), fundingTxOut, ByteVector(write(fundingScript))) + } else { + val fundingScript = multiSig2of2(fundingPubkey1, fundingPubkey2) + val fundingTxOut = TxOut(fundingAmount, pay2wsh(fundingScript)) + return Transactions.InputInfo(OutPoint(fundingTxId, fundingTxOutputIndex.toLong()), fundingTxOut, ByteVector(write(fundingScript))) + } } data class PairOfCommitTxs(val localSpec: CommitmentSpec, val localCommitTx: Transactions.TransactionWithInputInfo.CommitTx, val localHtlcTxs: List, val remoteSpec: CommitmentSpec, val remoteCommitTx: Transactions.TransactionWithInputInfo.CommitTx, val remoteHtlcTxs: List) @@ -257,13 +261,13 @@ object Helpers { ): Either { val localSpec = CommitmentSpec(localHtlcs, commitTxFeerate, toLocal = toLocal, toRemote = toRemote) val remoteSpec = CommitmentSpec(localHtlcs.map{ it.opposite() }.toSet(), commitTxFeerate, toLocal = toRemote, toRemote = toLocal) - + val isTaprootChannel = Features.canUseFeature(localParams.features, remoteParams.features, Feature.SimpleTaprootStaging) if (!localParams.isInitiator) { // They initiated the channel open, therefore they pay the fee: we need to make sure they can afford it! // Note that the reserve may not be always be met: we could be using dual funding with a large funding amount on // our side and a small funding amount on their side. But we shouldn't care as long as they can pay the fees for // the commitment transaction. - val fees = commitTxFee(remoteParams.dustLimit, remoteSpec) + val fees = commitTxFee(remoteParams.dustLimit, remoteSpec, isTaprootChannel) val missing = fees - remoteSpec.toLocal.truncateToSatoshi() if (missing > 0.sat) { return Either.Left(CannotAffordFirstCommitFees(channelId, missing = missing, fees = fees)) @@ -271,7 +275,7 @@ object Helpers { } val fundingPubKey = channelKeys.fundingPubKey(fundingTxIndex) - val commitmentInput = makeFundingInputInfo(fundingTxId, fundingTxOutputIndex, fundingAmount, fundingPubKey, remoteFundingPubkey) + val commitmentInput = makeFundingInputInfo(fundingTxId, fundingTxOutputIndex, fundingAmount, fundingPubKey, remoteFundingPubkey, isTaprootChannel) val localPerCommitmentPoint = channelKeys.commitmentPoint(localCommitmentIndex) val (localCommitTx, localHtlcTxs) = Commitments.makeLocalTxs( channelKeys, @@ -425,7 +429,8 @@ object Helpers { commitment.params.remoteParams.toSelfDelay, localDelayedPubkey, localParams.defaultFinalScriptPubKey.toByteArray(), - feerateDelayed + feerateDelayed, + commitment.isTaprootChannel ) }?.let { val sig = Transactions.sign(it, channelKeys.delayedPaymentKey.deriveForCommitment(localPerCommitmentPoint), SigHash.SIGHASH_ALL) @@ -459,7 +464,8 @@ object Helpers { commitment.params.remoteParams.toSelfDelay, localDelayedPubkey, localParams.defaultFinalScriptPubKey.toByteArray(), - feerateDelayed + feerateDelayed, + commitment.isTaprootChannel ) }?.let { val sig = Transactions.sign(it, channelKeys.delayedPaymentKey.deriveForCommitment(localPerCommitmentPoint), SigHash.SIGHASH_ALL) @@ -518,7 +524,8 @@ object Helpers { localPaymentPubkey, remoteHtlcPubkey, localHtlcPubkey, - remoteCommit.spec + remoteCommit.spec, + commitment.isTaprootChannel ) // we need to use a rather high fee for htlc-claim because we compete with the counterparty diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/states/Normal.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/states/Normal.kt index c88ec47c9..489ce52f1 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/states/Normal.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/states/Normal.kt @@ -389,7 +389,7 @@ data class Normal( targetFeerate = spliceStatus.command.feerate ) val commitTxFees = when { - commitments.params.localParams.isInitiator -> Transactions.commitTxFee(commitments.params.remoteParams.dustLimit, parentCommitment.remoteCommit.spec) + commitments.params.localParams.isInitiator -> Transactions.commitTxFee(commitments.params.remoteParams.dustLimit, parentCommitment.remoteCommit.spec, commitments.isTaprootChannel) else -> 0.sat } if (parentCommitment.localCommit.spec.toLocal + fundingContribution.toMilliSatoshi() < parentCommitment.localChannelReserve(commitments.params).max(commitTxFees)) { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Serialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Serialization.kt index 48aab0206..4a68d6a02 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Serialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v4/Serialization.kt @@ -690,6 +690,9 @@ object Serialization { is SpliceTx -> { write(0x0e); writeInputInfo(o.input); writeBtcObject(o.tx) } + is HtlcDelayedTx -> { + write(0x0f); writeInputInfo(o.input); writeBtcObject(o.tx) + } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/transactions/Scripts.kt b/src/commonMain/kotlin/fr/acinq/lightning/transactions/Scripts.kt index 985d09dec..fdb52ea55 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/transactions/Scripts.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/transactions/Scripts.kt @@ -2,8 +2,10 @@ package fr.acinq.lightning.transactions import fr.acinq.bitcoin.* import fr.acinq.bitcoin.ScriptEltMapping.code2elt +import fr.acinq.bitcoin.crypto.musig2.Musig2 import fr.acinq.lightning.CltvExpiry import fr.acinq.lightning.CltvExpiryDelta +import fr.acinq.lightning.transactions.Transactions.NUMS_POINT import fr.acinq.lightning.utils.sat /** @@ -30,6 +32,12 @@ object Scripts { ScriptWitness(listOf(ByteVector.empty, der(sig2, SigHash.SIGHASH_ALL), der(sig1, SigHash.SIGHASH_ALL), ByteVector(Script.write(multiSig2of2(pubkey1, pubkey2))))) } + fun sort(pubkeys: List): List = pubkeys.sortedWith { a, b -> LexicographicalOrdering.compare(a, b) } + + fun musig2Aggregate(pubkey1: PublicKey, pubkey2: PublicKey): XonlyPublicKey = Musig2.aggregateKeys(sort(listOf(pubkey1, pubkey2))) + + fun musig2FundingScript(pubkey1: PublicKey, pubkey2: PublicKey): List = Script.pay2tr(musig2Aggregate(pubkey1, pubkey2), null as ByteVector32?) + /** * minimal encoding of a number into a script element: * - OP_0 to OP_16 if 0 <= n <= 16 @@ -241,4 +249,112 @@ object Scripts { fun witnessHtlcWithRevocationSig(revocationSig: ByteVector64, revocationPubkey: PublicKey, htlcScript: ByteVector) = ScriptWitness(listOf(der(revocationSig, SigHash.SIGHASH_ALL), revocationPubkey.value, htlcScript)) + /** + * Specific scripts for taproot channels + */ + object Taproot { + val anchorScript: List = listOf(OP_16, OP_CHECKSEQUENCEVERIFY) + + val anchorScriptTree = ScriptTree.Leaf(0, anchorScript) + + fun toRevokeScript(revocationPubkey: PublicKey, localDelayedPaymentPubkey: PublicKey) = + listOf(OP_PUSHDATA(localDelayedPaymentPubkey.xOnly()), OP_DROP, OP_PUSHDATA(revocationPubkey.xOnly()), OP_CHECKSIG) + + fun toDelayScript(localDelayedPaymentPubkey: PublicKey, toLocalDelay: CltvExpiryDelta) = + listOf(OP_PUSHDATA(localDelayedPaymentPubkey.xOnly()), OP_CHECKSIG, encodeNumber(toLocalDelay.toLong()), OP_CHECKSEQUENCEVERIFY, OP_DROP) + + /** + * Taproot channels to-local key, used for the delayed to-local output + * + * @param revocationPubkey revocation key + * @param toSelfDelay self CsV delay + * @param localDelayedPaymentPubkey local delayed payment key + * @return an (XonlyPubkey, Parity) pair + */ + fun toLocalKey(revocationPubkey: PublicKey, toSelfDelay: CltvExpiryDelta, localDelayedPaymentPubkey: PublicKey): Pair { + val revokeScript = toRevokeScript(revocationPubkey, localDelayedPaymentPubkey) + val delayScript = toDelayScript(localDelayedPaymentPubkey, toSelfDelay) + val scriptTree = ScriptTree.Branch( + ScriptTree.Leaf(0, delayScript), + ScriptTree.Leaf(1, revokeScript), + ) + return XonlyPublicKey(NUMS_POINT).outputKey(Crypto.TaprootTweak.ScriptTweak(scriptTree)) + } + + /** + * + * @param revocationPubkey revocation key + * @param toSelfDelay to-self CSV delay + * @param localDelayedPaymentPubkey local delayed payment key + * @return a script tree with two leaves (to self with delay, and to revocation key) + */ + fun toLocalScriptTree(revocationPubkey: PublicKey, toSelfDelay: CltvExpiryDelta, localDelayedPaymentPubkey: PublicKey): ScriptTree.Branch { + return ScriptTree.Branch( + ScriptTree.Leaf(0, toDelayScript(localDelayedPaymentPubkey, toSelfDelay)), + ScriptTree.Leaf(1, toRevokeScript(revocationPubkey, localDelayedPaymentPubkey)), + ) + } + + fun toRemoteScript(remotePaymentPubkey: PublicKey) = listOf(OP_PUSHDATA(remotePaymentPubkey.xOnly()), OP_CHECKSIG, OP_1, OP_CHECKSEQUENCEVERIFY, OP_DROP) + + /** + * taproot channel to-remote key, used for the to-remote output + * + * @param remotePaymentPubkey remote key + * @return a (XonlyPubkey, Parity) pair + */ + fun toRemoteKey(remotePaymentPubkey: PublicKey): Pair { + val remoteScript = toRemoteScript(remotePaymentPubkey) + val scriptTree = ScriptTree.Leaf(0, remoteScript) + return XonlyPublicKey(NUMS_POINT).outputKey(scriptTree) + } + + /** + * + * @param remotePaymentPubkey remote key + * @return a script tree with a single leaf (to remote key, with a 1-block CSV delay) + */ + fun toRemoteScriptTree(remotePaymentPubkey: PublicKey) = ScriptTree.Leaf(0, toRemoteScript(remotePaymentPubkey)) + + fun offeredHtlcTimeoutScript(localHtlcPubkey: PublicKey, remoteHtlcPubkey: PublicKey) = listOf(OP_PUSHDATA(localHtlcPubkey.xOnly()), OP_CHECKSIGVERIFY, OP_PUSHDATA(remoteHtlcPubkey.xOnly()), OP_CHECKSIG) + + fun offeredHtlcSuccessScript(remoteHtlcPubkey: PublicKey, paymentHash: ByteVector32) = listOf( + // @formatter:off + OP_SIZE, encodeNumber(32), OP_EQUALVERIFY, + OP_HASH160, OP_PUSHDATA(Crypto.ripemd160(paymentHash)), OP_EQUALVERIFY, + OP_PUSHDATA(remoteHtlcPubkey.xOnly()), OP_CHECKSIG, + OP_1, OP_CHECKSEQUENCEVERIFY, OP_DROP + // @formatter:on + ) + + fun offeredHtlcTree(localHtlcPubkey: PublicKey, remoteHtlcPubkey: PublicKey, paymentHash: ByteVector32) = + ScriptTree.Branch( + ScriptTree.Leaf(0, offeredHtlcTimeoutScript(localHtlcPubkey, remoteHtlcPubkey)), + ScriptTree.Leaf(1, offeredHtlcSuccessScript(remoteHtlcPubkey, paymentHash)) + ) + + fun receivedHtlcTimeoutScript(remoteHtlcPubkey: PublicKey, lockTime: CltvExpiry) = listOf( + // @formatter:off + OP_PUSHDATA(remoteHtlcPubkey.xOnly()), OP_CHECKSIG, + OP_1, OP_CHECKSEQUENCEVERIFY, OP_DROP, + encodeNumber(lockTime.toLong()), OP_CHECKLOCKTIMEVERIFY, OP_DROP + // @formatter:on + ) + + fun receivedHtlcSuccessScript(localHtlcPubkey: PublicKey, remoteHtlcPubkey: PublicKey, paymentHash: ByteVector32) = listOf( + // @formatter:off + OP_SIZE, encodeNumber(32), OP_EQUALVERIFY, + OP_HASH160, OP_PUSHDATA(Crypto.ripemd160(paymentHash)), OP_EQUALVERIFY, + OP_PUSHDATA(localHtlcPubkey.xOnly()), OP_CHECKSIGVERIFY, + OP_PUSHDATA(remoteHtlcPubkey.xOnly()), OP_CHECKSIG + // @formatter:on + ) + + fun receivedHtlcTree(localHtlcPubkey: PublicKey, remoteHtlcPubkey: PublicKey, paymentHash: ByteVector32, lockTime: CltvExpiry): ScriptTree.Branch { + return ScriptTree.Branch( + ScriptTree.Leaf(0, receivedHtlcTimeoutScript(remoteHtlcPubkey, lockTime)), + ScriptTree.Leaf(1, receivedHtlcSuccessScript(localHtlcPubkey, remoteHtlcPubkey, paymentHash)), + ) + } + } } \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/transactions/Transactions.kt b/src/commonMain/kotlin/fr/acinq/lightning/transactions/Transactions.kt index b826695a4..31d58c496 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/transactions/Transactions.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/transactions/Transactions.kt @@ -18,15 +18,21 @@ package fr.acinq.lightning.transactions import fr.acinq.bitcoin.* import fr.acinq.bitcoin.crypto.Pack +import fr.acinq.bitcoin.crypto.musig2.IndividualNonce +import fr.acinq.bitcoin.crypto.musig2.Musig2 +import fr.acinq.bitcoin.crypto.musig2.SecretNonce +import fr.acinq.bitcoin.utils.Either import fr.acinq.bitcoin.utils.Try import fr.acinq.bitcoin.utils.runTrying import fr.acinq.lightning.CltvExpiryDelta import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.blockchain.fee.FeeratePerKw +import fr.acinq.lightning.channel.ChannelType import fr.acinq.lightning.channel.Commitments import fr.acinq.lightning.io.* import fr.acinq.lightning.transactions.CommitmentOutput.InHtlc import fr.acinq.lightning.transactions.CommitmentOutput.OutHtlc +import fr.acinq.lightning.transactions.Scripts.witnessToLocalDelayedAfterDelay import fr.acinq.lightning.utils.* import fr.acinq.lightning.wire.UpdateAddHtlc import kotlinx.serialization.Contextual @@ -41,12 +47,25 @@ typealias TransactionsCommitmentOutputs = List) : this(outPoint, txOut, ByteVector(Script.write(redeemScript))) } @@ -62,6 +81,12 @@ object Transactions { return (FeeratePerKw.MinimumRelayFeeRate * vsize / 1000).sat } + open fun sign(key: PrivateKey, sigHash: Int = SigHash.SIGHASH_ALL): ByteVector64 { + val inputIndex = tx.txIn.indexOfFirst { it.outPoint == input.outPoint } + require(inputIndex >= 0) { "transaction doesn't spend the input to sign" } + return sign(tx, inputIndex, input.redeemScript.toByteArray(), input.txOut.amount, key, sigHash) + } + @Serializable data class SpliceTx(override val input: InputInfo, @Contextual override val tx: Transaction) : TransactionWithInputInfo() @@ -81,7 +106,30 @@ object Transactions { ) : HtlcTx() @Serializable - data class HtlcTimeoutTx(override val input: InputInfo, @Contextual override val tx: Transaction, override val htlcId: Long) : HtlcTx() + data class HtlcTimeoutTx(override val input: InputInfo, @Contextual override val tx: Transaction, override val htlcId: Long) : HtlcTx() { + override fun sign(key: PrivateKey, sigHash: Int): ByteVector64 { + return when (val tree = input.scriptTree) { + null -> super.sign(key, sigHash) + else -> { + val branch = tree.scriptTree as ScriptTree.Branch + Transaction.signInputTaprootScriptPath(key, tx, 0, listOf(input.txOut), SigHash.SIGHASH_SINGLE or SigHash.SIGHASH_ANYONECANPAY, branch.left.hash()) + } + } + } + } + } + + @Serializable + data class HtlcDelayedTx(override val input: InputInfo, @Contextual override val tx: Transaction) : TransactionWithInputInfo() { + override fun sign(key: PrivateKey, sigHash: Int): ByteVector64 { + return when (input.scriptTree) { + null -> super.sign(key, sigHash) + else -> { + val branch = input.scriptTree.scriptTree as ScriptTree.Leaf + Transaction.signInputTaprootScriptPath(key, tx, 0, listOf(input.txOut), SigHash.SIGHASH_DEFAULT, branch.hash()) + } + } + } } @Serializable @@ -163,6 +211,7 @@ object Transactions { */ // legacy swap-in. witness is 2 signatures (73 bytes) + redeem script (77 bytes) const val swapInputWeightLegacy = 392 + // musig2 swap-in. witness is a single Schnorr signature (64 bytes) const val swapInputWeight = 233 @@ -230,14 +279,18 @@ object Transactions { * If you are adding multiple fees together for example, you should always add them in MilliSatoshi and then round * down to Satoshi. */ - fun commitTxFeeMsat(dustLimit: Satoshi, spec: CommitmentSpec): MilliSatoshi { + fun commitTxFeeMsat(dustLimit: Satoshi, spec: CommitmentSpec, isTaprootChannel: Boolean): MilliSatoshi { val trimmedOfferedHtlcs = trimOfferedHtlcs(dustLimit, spec) val trimmedReceivedHtlcs = trimReceivedHtlcs(dustLimit, spec) - val weight = Commitments.COMMIT_WEIGHT + Commitments.HTLC_OUTPUT_WEIGHT * (trimmedOfferedHtlcs.size + trimmedReceivedHtlcs.size) + val weight = if (isTaprootChannel) { + Commitments.COMMIT_WEIGHT_TAPROOT + Commitments.HTLC_OUTPUT_WEIGHT * (trimmedOfferedHtlcs.size + trimmedReceivedHtlcs.size) + } else { + Commitments.COMMIT_WEIGHT + Commitments.HTLC_OUTPUT_WEIGHT * (trimmedOfferedHtlcs.size + trimmedReceivedHtlcs.size) + } return weight2feeMsat(spec.feerate, weight) + (Commitments.ANCHOR_AMOUNT * 2).toMilliSatoshi() } - fun commitTxFee(dustLimit: Satoshi, spec: CommitmentSpec): Satoshi = commitTxFeeMsat(dustLimit, spec).truncateToSatoshi() + fun commitTxFee(dustLimit: Satoshi, spec: CommitmentSpec, isTaprootChannel: Boolean): Satoshi = commitTxFeeMsat(dustLimit, spec, isTaprootChannel).truncateToSatoshi() /** * @param commitTxNumber commit tx number @@ -290,7 +343,12 @@ object Transactions { * @param redeemScript redeem script that matches this output (most of them are p2wsh) * @param commitmentOutput commitment spec item this output is built from */ - data class CommitmentOutputLink(val output: TxOut, val redeemScript: List, val commitmentOutput: T) : Comparable> { + data class CommitmentOutputLink(val output: TxOut, val redeemScript: List, val scriptTree: ScriptTreeAndInternalKey?, val commitmentOutput: T) : Comparable> { + + constructor(output: TxOut, redeemScript: List, commitmentOutput: T) : this(output, redeemScript, null, commitmentOutput) + + constructor(output: TxOut, scriptTree: ScriptTreeAndInternalKey, commitmentOutput: T) : this(output, listOf(), scriptTree, commitmentOutput) + /** * We sort HTLC outputs according to BIP69 + CLTV as tie-breaker for offered HTLC, we do this only for the outgoing * HTLC because we must agree with the remote on the order of HTLC-Timeout transactions even for identical HTLC outputs. @@ -317,9 +375,10 @@ object Transactions { remotePaymentPubkey: PublicKey, localHtlcPubkey: PublicKey, remoteHtlcPubkey: PublicKey, - spec: CommitmentSpec + spec: CommitmentSpec, + isTaprootChannel: Boolean ): TransactionsCommitmentOutputs { - val commitFee = commitTxFee(localDustLimit, spec) + val commitFee = commitTxFee(localDustLimit, spec, isTaprootChannel) val (toLocalAmount: Satoshi, toRemoteAmount: Satoshi) = if (localIsInitiator) { Pair(spec.toLocal.truncateToSatoshi() - commitFee, spec.toRemote.truncateToSatoshi()) @@ -329,50 +388,130 @@ object Transactions { val outputs = ArrayList>() - if (toLocalAmount >= localDustLimit) outputs.add( - CommitmentOutputLink( - TxOut(toLocalAmount, Script.pay2wsh(Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey))), - Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey), - CommitmentOutput.ToLocal - ) - ) + if (toLocalAmount >= localDustLimit) { + when (isTaprootChannel) { + true -> { + val toLocalScriptTree = Scripts.Taproot.toLocalScriptTree(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey) + outputs.add( + CommitmentOutputLink( + TxOut(toLocalAmount, Script.pay2tr(XonlyPublicKey(NUMS_POINT), toLocalScriptTree)), + ScriptTreeAndInternalKey(toLocalScriptTree, NUMS_POINT.xOnly()), + CommitmentOutput.ToLocal + ) + ) + } + + else -> outputs.add( + CommitmentOutputLink( + TxOut(toLocalAmount, Script.pay2wsh(Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey))), + Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey), + CommitmentOutput.ToLocal + ) + ) + } + } if (toRemoteAmount >= localDustLimit) { - outputs.add( - CommitmentOutputLink( - TxOut(toRemoteAmount, Script.pay2wsh(Scripts.toRemoteDelayed(remotePaymentPubkey))), - Scripts.toRemoteDelayed(remotePaymentPubkey), - CommitmentOutput.ToRemote + when (isTaprootChannel) { + true -> { + val toRemoteScriptTree = Scripts.Taproot.toRemoteScriptTree(remotePaymentPubkey) + outputs.add( + CommitmentOutputLink( + TxOut(toRemoteAmount, Script.pay2tr(XonlyPublicKey(NUMS_POINT), toRemoteScriptTree)), + ScriptTreeAndInternalKey(toRemoteScriptTree, NUMS_POINT.xOnly()), + CommitmentOutput.ToRemote + ) + ) + } + + else -> outputs.add( + CommitmentOutputLink( + TxOut(toRemoteAmount, Script.pay2wsh(Scripts.toRemoteDelayed(remotePaymentPubkey))), + Scripts.toRemoteDelayed(remotePaymentPubkey), + CommitmentOutput.ToRemote + ) ) - ) + + } } val untrimmedHtlcs = trimOfferedHtlcs(localDustLimit, spec).isNotEmpty() || trimReceivedHtlcs(localDustLimit, spec).isNotEmpty() - if (untrimmedHtlcs || toLocalAmount >= localDustLimit) - outputs.add( - CommitmentOutputLink( - TxOut(Commitments.ANCHOR_AMOUNT, Script.pay2wsh(Scripts.toAnchor(localFundingPubkey))), - Scripts.toAnchor(localFundingPubkey), - CommitmentOutput.ToLocalAnchor(localFundingPubkey) + if (untrimmedHtlcs || toLocalAmount >= localDustLimit) { + when (isTaprootChannel) { + true -> { + outputs.add( + CommitmentOutputLink( + TxOut(Commitments.ANCHOR_AMOUNT, Script.pay2tr(localDelayedPaymentPubkey.xOnly(), Scripts.Taproot.anchorScriptTree)), + Scripts.Taproot.anchorScript, + CommitmentOutput.ToLocalAnchor(localFundingPubkey) + ) + ) + } + + else -> outputs.add( + CommitmentOutputLink( + TxOut(Commitments.ANCHOR_AMOUNT, Script.pay2wsh(Scripts.toAnchor(localFundingPubkey))), + Scripts.toAnchor(localFundingPubkey), + CommitmentOutput.ToLocalAnchor(localFundingPubkey) + ) ) - ) - if (untrimmedHtlcs || toRemoteAmount >= localDustLimit) - outputs.add( - CommitmentOutputLink( - TxOut(Commitments.ANCHOR_AMOUNT, Script.pay2wsh(Scripts.toAnchor(remoteFundingPubkey))), - Scripts.toAnchor(remoteFundingPubkey), - CommitmentOutput.ToLocalAnchor(remoteFundingPubkey) + } + } + + if (untrimmedHtlcs || toRemoteAmount >= localDustLimit) { + when (isTaprootChannel) { + true -> outputs.add( + CommitmentOutputLink( + TxOut(Commitments.ANCHOR_AMOUNT, Script.pay2tr(remotePaymentPubkey.xOnly(), Scripts.Taproot.anchorScriptTree)), + Scripts.Taproot.anchorScript, + CommitmentOutput.ToLocalAnchor(remoteFundingPubkey) + ) ) - ) + + else -> outputs.add( + CommitmentOutputLink( + TxOut(Commitments.ANCHOR_AMOUNT, Script.pay2wsh(Scripts.toAnchor(remoteFundingPubkey))), + Scripts.toAnchor(remoteFundingPubkey), + CommitmentOutput.ToLocalAnchor(remoteFundingPubkey) + ) + ) + } + } trimOfferedHtlcs(localDustLimit, spec).forEach { htlc -> - val redeemScript = Scripts.htlcOffered(localHtlcPubkey, remoteHtlcPubkey, localRevocationPubkey, Crypto.ripemd160(htlc.add.paymentHash.toByteArray())) - outputs.add(CommitmentOutputLink(TxOut(htlc.add.amountMsat.truncateToSatoshi(), Script.pay2wsh(redeemScript)), redeemScript, OutHtlc(htlc))) + when (isTaprootChannel) { + true -> { + val offeredHtlcTree = Scripts.Taproot.offeredHtlcTree(localHtlcPubkey, remoteHtlcPubkey, htlc.add.paymentHash) + outputs.add( + CommitmentOutputLink( + TxOut(htlc.add.amountMsat.truncateToSatoshi(), Script.pay2tr(localRevocationPubkey.xOnly(), offeredHtlcTree)), ScriptTreeAndInternalKey(offeredHtlcTree, localRevocationPubkey.xOnly()), OutHtlc(htlc) + ) + ) + } + + else -> { + val redeemScript = Scripts.htlcOffered(localHtlcPubkey, remoteHtlcPubkey, localRevocationPubkey, Crypto.ripemd160(htlc.add.paymentHash.toByteArray())) + outputs.add(CommitmentOutputLink(TxOut(htlc.add.amountMsat.truncateToSatoshi(), Script.pay2wsh(redeemScript)), redeemScript, OutHtlc(htlc))) + } + } } trimReceivedHtlcs(localDustLimit, spec).forEach { htlc -> - val redeemScript = Scripts.htlcReceived(localHtlcPubkey, remoteHtlcPubkey, localRevocationPubkey, Crypto.ripemd160(htlc.add.paymentHash.toByteArray()), htlc.add.cltvExpiry) - outputs.add(CommitmentOutputLink(TxOut(htlc.add.amountMsat.truncateToSatoshi(), Script.pay2wsh(redeemScript)), redeemScript, InHtlc(htlc))) + when (isTaprootChannel) { + true -> { + val receivedHtlcTree = Scripts.Taproot.receivedHtlcTree(localHtlcPubkey, remoteHtlcPubkey, htlc.add.paymentHash, htlc.add.cltvExpiry) + outputs.add( + CommitmentOutputLink( + TxOut(htlc.add.amountMsat.truncateToSatoshi(), Script.pay2tr(localRevocationPubkey.xOnly(), receivedHtlcTree)), ScriptTreeAndInternalKey(receivedHtlcTree, localRevocationPubkey.xOnly()), InHtlc(htlc) + ) + ) + } + + else -> { + val redeemScript = Scripts.htlcReceived(localHtlcPubkey, remoteHtlcPubkey, localRevocationPubkey, Crypto.ripemd160(htlc.add.paymentHash.toByteArray()), htlc.add.cltvExpiry) + outputs.add(CommitmentOutputLink(TxOut(htlc.add.amountMsat.truncateToSatoshi(), Script.pay2wsh(redeemScript)), redeemScript, InHtlc(htlc))) + } + } } return outputs.apply { sort() } @@ -400,8 +539,14 @@ object Transactions { } sealed class TxResult { - data class Skipped(val why: TxGenerationSkipped) : TxResult() - data class Success(val result: T) : TxResult() + abstract fun map(f: (T) -> R): TxResult + + data class Skipped(val why: TxGenerationSkipped) : TxResult() { + override fun map(f: (T) -> R): TxResult = Skipped(why) + } + data class Success(val result: T) : TxResult() { + override fun map(f: (T) -> R): TxResult = Success(f(result)) + } } private fun makeHtlcTimeoutTx( @@ -412,7 +557,8 @@ object Transactions { localRevocationPubkey: PublicKey, toLocalDelay: CltvExpiryDelta, localDelayedPaymentPubkey: PublicKey, - feerate: FeeratePerKw + feerate: FeeratePerKw, + isTaprootChannel: Boolean ): TxResult { val fee = weight2fee(feerate, Commitments.HTLC_TIMEOUT_WEIGHT) val redeemScript = output.redeemScript @@ -421,14 +567,30 @@ object Transactions { return if (amount < localDustLimit) { TxResult.Skipped(TxGenerationSkipped.AmountBelowDustLimit) } else { - val input = InputInfo(OutPoint(commitTx, outputIndex.toLong()), commitTx.txOut[outputIndex], ByteVector(Script.write(redeemScript))) - val tx = Transaction( - version = 2, - txIn = listOf(TxIn(input.outPoint, ByteVector.empty, 1L)), - txOut = listOf(TxOut(amount, Script.pay2wsh(Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey)))), - lockTime = htlc.cltvExpiry.toLong() - ) - TxResult.Success(TransactionWithInputInfo.HtlcTx.HtlcTimeoutTx(input, tx, htlc.id)) + when (isTaprootChannel) { + true -> { + val input = InputInfo(OutPoint(commitTx, outputIndex.toLong()), commitTx.txOut[outputIndex], output.scriptTree!!.publicKeyScript, output.scriptTree) + val tree = ScriptTree.Leaf(0, Scripts.Taproot.toDelayScript(localDelayedPaymentPubkey, toLocalDelay)) + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(input.outPoint, ByteVector.empty, 1L)), + txOut = listOf(TxOut(amount, Script.pay2tr(localRevocationPubkey.xOnly(), tree))), + lockTime = htlc.cltvExpiry.toLong() + ) + TxResult.Success(TransactionWithInputInfo.HtlcTx.HtlcTimeoutTx(input, tx, htlc.id)) + } + + else -> { + val input = InputInfo(OutPoint(commitTx, outputIndex.toLong()), commitTx.txOut[outputIndex], ByteVector(Script.write(redeemScript))) + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(input.outPoint, ByteVector.empty, 1L)), + txOut = listOf(TxOut(amount, Script.pay2wsh(Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey)))), + lockTime = htlc.cltvExpiry.toLong() + ) + TxResult.Success(TransactionWithInputInfo.HtlcTx.HtlcTimeoutTx(input, tx, htlc.id)) + } + } } } @@ -440,7 +602,8 @@ object Transactions { localRevocationPubkey: PublicKey, toLocalDelay: CltvExpiryDelta, localDelayedPaymentPubkey: PublicKey, - feerate: FeeratePerKw + feerate: FeeratePerKw, + isTaprootChannel: Boolean ): TxResult { val fee = weight2fee(feerate, Commitments.HTLC_SUCCESS_WEIGHT) val redeemScript = output.redeemScript @@ -449,14 +612,30 @@ object Transactions { return if (amount < localDustLimit) { TxResult.Skipped(TxGenerationSkipped.AmountBelowDustLimit) } else { - val input = InputInfo(OutPoint(commitTx, outputIndex.toLong()), commitTx.txOut[outputIndex], ByteVector(Script.write(redeemScript))) - val tx = Transaction( - version = 2, - txIn = listOf(TxIn(input.outPoint, ByteVector.empty, 1L)), - txOut = listOf(TxOut(amount, Script.pay2wsh(Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey)))), - lockTime = 0 - ) - TxResult.Success(TransactionWithInputInfo.HtlcTx.HtlcSuccessTx(input, tx, htlc.paymentHash, htlc.id)) + when (isTaprootChannel) { + true -> { + val input = InputInfo(OutPoint(commitTx, outputIndex.toLong()), commitTx.txOut[outputIndex], output.scriptTree!!.publicKeyScript, output.scriptTree) + val tree = ScriptTree.Leaf(0, Scripts.Taproot.toDelayScript(localDelayedPaymentPubkey, toLocalDelay)) + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(input.outPoint, ByteVector.empty, 1L)), + txOut = listOf(TxOut(amount, Script.pay2tr(localRevocationPubkey.xOnly(), tree))), + lockTime = 0 + ) + TxResult.Success(TransactionWithInputInfo.HtlcTx.HtlcSuccessTx(input, tx, htlc.paymentHash, htlc.id)) + } + + else -> { + val input = InputInfo(OutPoint(commitTx, outputIndex.toLong()), commitTx.txOut[outputIndex], ByteVector(Script.write(redeemScript))) + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(input.outPoint, ByteVector.empty, 1L)), + txOut = listOf(TxOut(amount, Script.pay2wsh(Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey)))), + lockTime = 0 + ) + TxResult.Success(TransactionWithInputInfo.HtlcTx.HtlcSuccessTx(input, tx, htlc.paymentHash, htlc.id)) + } + } } } @@ -467,21 +646,22 @@ object Transactions { toLocalDelay: CltvExpiryDelta, localDelayedPaymentPubkey: PublicKey, feerate: FeeratePerKw, - outputs: TransactionsCommitmentOutputs + outputs: TransactionsCommitmentOutputs, + isTaprootChannel: Boolean ): List { val htlcTimeoutTxs = outputs .mapIndexedNotNull map@{ outputIndex, link -> val outHtlc = link.commitmentOutput as? OutHtlc ?: return@map null - val co = CommitmentOutputLink(link.output, link.redeemScript, outHtlc) - makeHtlcTimeoutTx(commitTx, co, outputIndex, localDustLimit, localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey, feerate) + val co = CommitmentOutputLink(link.output, link.redeemScript, link.scriptTree, outHtlc) + makeHtlcTimeoutTx(commitTx, co, outputIndex, localDustLimit, localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey, feerate, isTaprootChannel) } .mapNotNull { (it as? TxResult.Success)?.result } val htlcSuccessTxs = outputs .mapIndexedNotNull map@{ outputIndex, link -> val inHtlc = link.commitmentOutput as? InHtlc ?: return@map null - val co = CommitmentOutputLink(link.output, link.redeemScript, inHtlc) - makeHtlcSuccessTx(commitTx, co, outputIndex, localDustLimit, localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey, feerate) + val co = CommitmentOutputLink(link.output, link.redeemScript, link.scriptTree, inHtlc) + makeHtlcSuccessTx(commitTx, co, outputIndex, localDustLimit, localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey, feerate, isTaprootChannel) } .mapNotNull { (it as? TxResult.Success)?.result } @@ -594,6 +774,53 @@ object Transactions { } } +fun makeHtlcDelayedTx(htlcTx: Transaction, + localDustLimit: Satoshi, + localRevocationPubkey: PublicKey, + toLocalDelay: CltvExpiryDelta, + localDelayedPaymentPubkey: PublicKey, + localFinalScriptPubKey: ByteArray, + feeratePerKw: FeeratePerKw, + isTaprootChannel: Boolean): TxResult { + return when(isTaprootChannel) { + true -> { + val htlcTxTree = ScriptTree.Leaf (0, Scripts.Taproot.toDelayScript(localDelayedPaymentPubkey, toLocalDelay)) + val scriptTree = ScriptTreeAndInternalKey(htlcTxTree, localRevocationPubkey.xOnly()) + when (val pubkeyScriptIndex = findPubKeyScriptIndex(htlcTx, scriptTree.publicKeyScript.toByteArray())) { + is TxResult.Skipped -> TxResult.Skipped(pubkeyScriptIndex.why) + is TxResult.Success -> { + val outputIndex = pubkeyScriptIndex.result + val input = InputInfo(OutPoint(htlcTx, outputIndex.toLong()), htlcTx.txOut[outputIndex], scriptTree.publicKeyScript, ScriptTreeAndInternalKey(htlcTxTree, localRevocationPubkey.xOnly())) + // unsigned transaction + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(input.outPoint, ByteVector.empty, toLocalDelay.toLong())), + txOut = listOf(TxOut(Satoshi(0), localFinalScriptPubKey)), + lockTime = 0 + ) + val weight = run { + val witness = Script.witnessScriptPathPay2tr(localRevocationPubkey.xOnly(), htlcTxTree, ScriptWitness(listOf(ByteVector64.Zeroes)), htlcTxTree) + tx.updateWitness(0, witness).weight() + } + val fee = weight2fee(feeratePerKw, weight) + val amount = input.txOut.amount - fee + if (amount < localDustLimit) { + TxResult.Skipped(TxGenerationSkipped.AmountBelowDustLimit) + } else { + val tx1 = tx.copy(txOut = listOf(tx.txOut.first().copy(amount = amount))) + TxResult.Success(TransactionWithInputInfo.HtlcDelayedTx(input, tx1)) + } + } + } + } + else -> { + makeClaimLocalDelayedOutputTx(htlcTx, localDustLimit, localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey, localFinalScriptPubKey, feeratePerKw, isTaprootChannel).map { + it -> TransactionWithInputInfo.HtlcDelayedTx(it.input, it.tx) + } + } + } + } + fun makeClaimLocalDelayedOutputTx( delayedOutputTx: Transaction, localDustLimit: Satoshi, @@ -601,15 +828,31 @@ object Transactions { toLocalDelay: CltvExpiryDelta, localDelayedPaymentPubkey: PublicKey, localFinalScriptPubKey: ByteArray, - feerate: FeeratePerKw + feerate: FeeratePerKw, + isTaprootChannel: Boolean ): TxResult { - val redeemScript = Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey) - val pubkeyScript = Script.write(Script.pay2wsh(redeemScript)) + + val (redeemScript, pubkeyScript, scriptTree_opt) = when(isTaprootChannel) { + true -> { + val toLocalScriptTree = Scripts.Taproot.toLocalScriptTree(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey) + Triple( + Scripts.Taproot.toDelayScript(localDelayedPaymentPubkey, toLocalDelay), + Script.write(Script.pay2tr(XonlyPublicKey(NUMS_POINT), toLocalScriptTree)), + ScriptTreeAndInternalKey(toLocalScriptTree, NUMS_POINT.xOnly()) + ) + } + + else -> { + val redeemScript = Scripts.toLocalDelayed(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey) + Triple(redeemScript, Script.write(Script.pay2wsh(redeemScript)), null) + } + } + return when (val pubkeyScriptIndex = findPubKeyScriptIndex(delayedOutputTx, pubkeyScript)) { is TxResult.Skipped -> TxResult.Skipped(pubkeyScriptIndex.why) is TxResult.Success -> { val outputIndex = pubkeyScriptIndex.result - val input = InputInfo(OutPoint(delayedOutputTx, outputIndex.toLong()), delayedOutputTx.txOut[outputIndex], ByteVector(Script.write(redeemScript))) + val input = InputInfo(OutPoint(delayedOutputTx, outputIndex.toLong()), delayedOutputTx.txOut[outputIndex], ByteVector(Script.write(redeemScript)), scriptTree_opt) // unsigned transaction val tx = Transaction( version = 2, @@ -618,7 +861,14 @@ object Transactions { lockTime = 0 ) // compute weight with a dummy 73 bytes signature (the largest you can get) - val weight = addSigs(TransactionWithInputInfo.ClaimLocalDelayedOutputTx(input, tx), PlaceHolderSig).tx.weight() + val weight = when(isTaprootChannel) { + true -> { + val toLocalScriptTree = Scripts.Taproot.toLocalScriptTree(localRevocationPubkey, toLocalDelay, localDelayedPaymentPubkey) + val witness = Script.witnessScriptPathPay2tr(XonlyPublicKey(NUMS_POINT), toLocalScriptTree.left as ScriptTree.Leaf, ScriptWitness(listOf(ByteVector64.Zeroes)), toLocalScriptTree) + tx.updateWitness(0, witness).weight() + } + else -> addSigs(TransactionWithInputInfo.ClaimLocalDelayedOutputTx(input, tx), PlaceHolderSig).tx.weight() + } val fee = weight2fee(feerate, weight) val amount = input.txOut.amount - fee if (amount < localDustLimit) { @@ -811,6 +1061,39 @@ object Transactions { return sign(txInfo.tx, inputIndex, txInfo.input.redeemScript.toByteArray(), txInfo.input.txOut.amount, key, sigHash) } + fun partialSign( + key: PrivateKey, tx: Transaction, inputIndex: Int, spentOutputs: List, + localFundingPublicKey: PublicKey, remoteFundingPublicKey: PublicKey, + localNonce: Pair, remoteNextLocalNonce: IndividualNonce + ): Either { + val publicKeys = Scripts.sort(listOf(localFundingPublicKey, remoteFundingPublicKey)) + return Musig2.signTaprootInput(key, tx, inputIndex, spentOutputs, publicKeys, localNonce.first, listOf(localNonce.second, remoteNextLocalNonce), null) + } + + fun partialSign( + txinfo: TransactionWithInputInfo, key: PrivateKey, + localFundingPublicKey: PublicKey, remoteFundingPublicKey: PublicKey, + localNonce: Pair, remoteNextLocalNonce: IndividualNonce + ): Either { + val inputIndex = txinfo.tx.txIn.indexOfFirst { it.outPoint == txinfo.input.outPoint } + return partialSign(key, txinfo.tx, inputIndex, listOf(txinfo.input.txOut), localFundingPublicKey, remoteFundingPublicKey, localNonce, remoteNextLocalNonce) + } + + fun aggregatePartialSignatures( + txinfo: TransactionWithInputInfo, + localSig: ByteVector32, remoteSig: ByteVector32, + localFundingPublicKey: PublicKey, remoteFundingPublicKey: PublicKey, + localNonce: IndividualNonce, remoteNonce: IndividualNonce + ): Either { + return Musig2.aggregateTaprootSignatures( + listOf(localSig, remoteSig), txinfo.tx, txinfo.tx.txIn.indexOfFirst { it.outPoint == txinfo.input.outPoint }, + listOf(txinfo.input.txOut), + Scripts.sort(listOf(localFundingPublicKey, remoteFundingPublicKey)), + listOf(localNonce, remoteNonce), + null + ) + } + fun addSigs( commitTx: TransactionWithInputInfo.CommitTx, localFundingPubkey: PublicKey, @@ -838,7 +1121,14 @@ object Transactions { } fun addSigs(htlcTimeoutTx: TransactionWithInputInfo.HtlcTx.HtlcTimeoutTx, localSig: ByteVector64, remoteSig: ByteVector64): TransactionWithInputInfo.HtlcTx.HtlcTimeoutTx { - val witness = Scripts.witnessHtlcTimeout(localSig, remoteSig, htlcTimeoutTx.input.redeemScript) + val witness = when(htlcTimeoutTx.input.scriptTree) { + null -> Scripts.witnessHtlcTimeout(localSig, remoteSig, htlcTimeoutTx.input.redeemScript) + else -> { + val branch = htlcTimeoutTx.input.scriptTree.scriptTree as ScriptTree.Branch + val sigHash = (SigHash.SIGHASH_SINGLE or SigHash.SIGHASH_ANYONECANPAY).toByte() + Script.witnessScriptPathPay2tr(htlcTimeoutTx.input.scriptTree.internalKey, branch.left as ScriptTree.Leaf, ScriptWitness(listOf(remoteSig.concat(sigHash), localSig.concat(sigHash))), branch) + } + } return htlcTimeoutTx.copy(tx = htlcTimeoutTx.tx.updateWitness(0, witness)) } @@ -852,6 +1142,14 @@ object Transactions { return claimHtlcTimeoutTx.copy(tx = claimHtlcTimeoutTx.tx.updateWitness(0, witness)) } + fun addSigs(htlcDelayedTx: TransactionWithInputInfo.HtlcDelayedTx, localSig: ByteVector64): TransactionWithInputInfo.HtlcDelayedTx { + val witness = when(htlcDelayedTx.input.scriptTree) { + null -> witnessToLocalDelayedAfterDelay(localSig, htlcDelayedTx.input.redeemScript) + else -> Script.witnessScriptPathPay2tr(htlcDelayedTx.input.scriptTree.internalKey, htlcDelayedTx.input.scriptTree.scriptTree as ScriptTree.Leaf, ScriptWitness(listOf(localSig)), htlcDelayedTx.input.scriptTree.scriptTree) + } + return htlcDelayedTx.copy(tx = htlcDelayedTx.tx.updateWitness(0, witness)) + } + fun addSigs(claimRemoteDelayed: TransactionWithInputInfo.ClaimRemoteCommitMainOutputTx.ClaimRemoteDelayedOutputTx, localSig: ByteVector64): TransactionWithInputInfo.ClaimRemoteCommitMainOutputTx.ClaimRemoteDelayedOutputTx { val witness = Scripts.witnessToRemoteDelayedAfterDelay(localSig, claimRemoteDelayed.input.redeemScript) return claimRemoteDelayed.copy(tx = claimRemoteDelayed.tx.updateWitness(0, witness)) @@ -872,6 +1170,10 @@ object Transactions { return closingTx.copy(tx = closingTx.tx.updateWitness(0, witness)) } + fun addAggregatedSignature(commitTx: TransactionWithInputInfo.CommitTx, aggregatedSignature: ByteVector64): TransactionWithInputInfo.CommitTx { + return commitTx.copy(tx = commitTx.tx.updateWitness(0, Script.witnessKeyPathPay2tr(aggregatedSignature))) + } + fun checkSpendable(txinfo: TransactionWithInputInfo): Try = runTrying { Transaction.correctlySpends(txinfo.tx, mapOf(txinfo.tx.txIn.first().outPoint to txinfo.input.txOut), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt index 0ac111bab..97bdfa564 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/ChannelTlv.kt @@ -1,6 +1,7 @@ package fr.acinq.lightning.wire import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.crypto.musig2.IndividualNonce import fr.acinq.bitcoin.io.Input import fr.acinq.bitcoin.io.Output import fr.acinq.lightning.Features @@ -9,6 +10,7 @@ import fr.acinq.lightning.ShortChannelId import fr.acinq.lightning.blockchain.fee.FeeratePerKw import fr.acinq.lightning.channel.ChannelType import fr.acinq.lightning.channel.Origin +import fr.acinq.lightning.channel.PartialSignatureWithNonce import fr.acinq.lightning.utils.msat import fr.acinq.lightning.utils.sat import fr.acinq.lightning.utils.toByteVector @@ -201,6 +203,23 @@ sealed class ChannelTlv : Tlv { override fun read(input: Input): PushAmountTlv = PushAmountTlv(LightningCodecs.tu64(input).msat) } } + + data class NextLocalNoncesTlv(val nonces: List) : ChannelTlv() { + override val tag: Long get() = NextLocalNoncesTlv.tag + + override fun write(out: Output) { + nonces.forEach { LightningCodecs.writeBytes(it.toByteArray(), out) } + } + + companion object : TlvValueReader { + const val tag: Long = 4 + override fun read(input: Input): NextLocalNoncesTlv { + val count = input.availableBytes / 66 + val nonces = (0 until count).map { IndividualNonce(LightningCodecs.bytes(input, 66)) } + return NextLocalNoncesTlv(nonces) + } + } + } } sealed class ChannelReadyTlv : Tlv { @@ -213,6 +232,19 @@ sealed class ChannelReadyTlv : Tlv { override fun read(input: Input): ShortChannelIdTlv = ShortChannelIdTlv(ShortChannelId(LightningCodecs.u64(input))) } } + + data class NextLocalNonceTlv(val nonce: IndividualNonce) : ChannelReadyTlv() { + override val tag: Long get() = NextLocalNonceTlv.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(nonce.toByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 4 + override fun read(input: Input): NextLocalNonceTlv = NextLocalNonceTlv(IndividualNonce(LightningCodecs.bytes(input, 66))) + } + } } sealed class CommitSigTlv : Tlv { @@ -266,6 +298,27 @@ sealed class CommitSigTlv : Tlv { override fun read(input: Input): Batch = Batch(size = LightningCodecs.tu16(input)) } } + + data class PartialSignatureWithNonceTlv(val psig: PartialSignatureWithNonce) : CommitSigTlv() { + override val tag: Long get() = PartialSignatureWithNonceTlv.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(psig.partialSig, out) + LightningCodecs.writeBytes(psig.nonce.toByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 2 + override fun read(input: Input): PartialSignatureWithNonceTlv { + return PartialSignatureWithNonceTlv( + PartialSignatureWithNonce( + LightningCodecs.bytes(input, 32).byteVector32(), + IndividualNonce(LightningCodecs.bytes(input, 66)) + ) + ) + } + } + } } sealed class RevokeAndAckTlv : Tlv { @@ -278,6 +331,23 @@ sealed class RevokeAndAckTlv : Tlv { override fun read(input: Input): ChannelData = ChannelData(EncryptedChannelData(LightningCodecs.bytes(input, input.availableBytes).toByteVector())) } } + + data class NextLocalNoncesTlv(val nonces: List) : ChannelTlv() { + override val tag: Long get() = NextLocalNoncesTlv.tag + + override fun write(out: Output) { + nonces.forEach { LightningCodecs.writeBytes(it.toByteArray(), out) } + } + + companion object : TlvValueReader { + const val tag: Long = 4 + override fun read(input: Input): NextLocalNoncesTlv { + val count = input.availableBytes / 66 + val nonces = (0 until count).map { IndividualNonce(LightningCodecs.bytes(input, 66)) } + return NextLocalNoncesTlv(nonces) + } + } + } } sealed class ChannelReestablishTlv : Tlv { @@ -300,6 +370,23 @@ sealed class ChannelReestablishTlv : Tlv { override fun read(input: Input): ChannelData = ChannelData(EncryptedChannelData(LightningCodecs.bytes(input, input.availableBytes).toByteVector())) } } + + data class NextLocalNoncesTlv(val nonces: List) : ChannelTlv() { + override val tag: Long get() = NextLocalNoncesTlv.tag + + override fun write(out: Output) { + nonces.forEach { LightningCodecs.writeBytes(it.toByteArray(), out) } + } + + companion object : TlvValueReader { + const val tag: Long = 4 + override fun read(input: Input): NextLocalNoncesTlv { + val count = input.availableBytes / 66 + val nonces = (0 until count).map { IndividualNonce(LightningCodecs.bytes(input, 66)) } + return NextLocalNoncesTlv(nonces) + } + } + } } sealed class ShutdownTlv : Tlv { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt index 76cc0b7ee..c71e9f957 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/LightningMessages.kt @@ -708,6 +708,7 @@ data class OpenDualFundedChannel( ChannelTlv.RequestFunds.tag to ChannelTlv.RequestFunds as TlvValueReader, ChannelTlv.OriginTlv.tag to ChannelTlv.OriginTlv.Companion as TlvValueReader, ChannelTlv.PushAmountTlv.tag to ChannelTlv.PushAmountTlv.Companion as TlvValueReader, + ChannelTlv.NextLocalNoncesTlv.tag to ChannelTlv.NextLocalNoncesTlv.Companion as TlvValueReader, ) override fun read(input: Input): OpenDualFundedChannel = OpenDualFundedChannel( @@ -788,6 +789,7 @@ data class AcceptDualFundedChannel( ChannelTlv.RequireConfirmedInputsTlv.tag to ChannelTlv.RequireConfirmedInputsTlv as TlvValueReader, ChannelTlv.WillFund.tag to ChannelTlv.WillFund as TlvValueReader, ChannelTlv.PushAmountTlv.tag to ChannelTlv.PushAmountTlv.Companion as TlvValueReader, + ChannelTlv.NextLocalNoncesTlv.tag to ChannelTlv.NextLocalNoncesTlv.Companion as TlvValueReader, ) override fun read(input: Input): AcceptDualFundedChannel = AcceptDualFundedChannel( @@ -881,7 +883,9 @@ data class ChannelReady( const val type: Long = 36 @Suppress("UNCHECKED_CAST") - val readers = mapOf(ChannelReadyTlv.ShortChannelIdTlv.tag to ChannelReadyTlv.ShortChannelIdTlv.Companion as TlvValueReader) + val readers = mapOf( + ChannelReadyTlv.ShortChannelIdTlv.tag to ChannelReadyTlv.ShortChannelIdTlv.Companion as TlvValueReader, + ChannelReadyTlv.NextLocalNonceTlv.tag to ChannelReadyTlv.NextLocalNonceTlv.Companion as TlvValueReader) override fun read(input: Input) = ChannelReady( ByteVector32(LightningCodecs.bytes(input, 32)), @@ -1207,6 +1211,7 @@ data class CommitSig( CommitSigTlv.ChannelData.tag to CommitSigTlv.ChannelData.Companion as TlvValueReader, CommitSigTlv.AlternativeFeerateSigs.tag to CommitSigTlv.AlternativeFeerateSigs.Companion as TlvValueReader, CommitSigTlv.Batch.tag to CommitSigTlv.Batch.Companion as TlvValueReader, + CommitSigTlv.PartialSignatureWithNonceTlv.tag to CommitSigTlv.PartialSignatureWithNonceTlv.Companion as TlvValueReader, ) override fun read(input: Input): CommitSig { @@ -1244,7 +1249,10 @@ data class RevokeAndAck( const val type: Long = 133 @Suppress("UNCHECKED_CAST") - val readers = mapOf(RevokeAndAckTlv.ChannelData.tag to RevokeAndAckTlv.ChannelData.Companion as TlvValueReader) + val readers = mapOf( + RevokeAndAckTlv.ChannelData.tag to RevokeAndAckTlv.ChannelData.Companion as TlvValueReader, + RevokeAndAckTlv.NextLocalNoncesTlv.tag to RevokeAndAckTlv.NextLocalNoncesTlv.Companion as TlvValueReader + ) override fun read(input: Input): RevokeAndAck { return RevokeAndAck( @@ -1310,6 +1318,7 @@ data class ChannelReestablish( val readers = mapOf( ChannelReestablishTlv.ChannelData.tag to ChannelReestablishTlv.ChannelData.Companion as TlvValueReader, ChannelReestablishTlv.NextFunding.tag to ChannelReestablishTlv.NextFunding.Companion as TlvValueReader, + ChannelReestablishTlv.NextLocalNoncesTlv.tag to ChannelReestablishTlv.NextLocalNoncesTlv.Companion as TlvValueReader, ) override fun read(input: Input): ChannelReestablish { diff --git a/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt index f80517101..565e07b32 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt @@ -226,7 +226,7 @@ class NormalTestsCommon : LightningTestSuite() { val (_, alice4) = crossSign(bob3, alice3) val aliceCommit = alice4.commitments.active.first().localCommit assertTrue(aliceCommit.publishableTxs.commitTx.tx.txOut.all { txOut -> txOut.amount > 0.sat }) - val aliceBalance = aliceCommit.spec.toLocal - commitTxFeeMsat(alice4.commitments.params.localParams.dustLimit, aliceCommit.spec) + val aliceBalance = aliceCommit.spec.toLocal - commitTxFeeMsat(alice4.commitments.params.localParams.dustLimit, aliceCommit.spec, alice4.commitments.isTaprootChannel) assertTrue(aliceBalance >= 0.msat) assertTrue(aliceBalance < alice4.commitments.latest.localChannelReserve) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/transactions/AnchorOutputsTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/transactions/AnchorOutputsTestsCommon.kt index e57048707..2fa4d4b0f 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/transactions/AnchorOutputsTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/transactions/AnchorOutputsTestsCommon.kt @@ -184,7 +184,8 @@ class AnchorOutputsTestsCommon { remote_payment_privkey.publicKey(), local_htlc_privkey.publicKey(), remote_htlc_privkey.publicKey(), - spec + spec, + false ) val commitTx = Transactions.makeCommitTx( commitTxInput, @@ -201,7 +202,7 @@ class AnchorOutputsTestsCommon { val txs = testCase.HtlcDescs.map { it.ResolutionTx.txid to it.ResolutionTx }.toMap() val remoteHtlcSigs = testCase.HtlcDescs.map { it.ResolutionTx.txid to ByteVector(it.RemoteSigHex) }.toMap() - val htlcTxs = Transactions.makeHtlcTxs(commitTx.tx, 546.sat, local_revocation_pubkey, CltvExpiryDelta(144), local_delayedpubkey, spec.feerate, outputs) + val htlcTxs = Transactions.makeHtlcTxs(commitTx.tx, 546.sat, local_revocation_pubkey, CltvExpiryDelta(144), local_delayedpubkey, spec.feerate, outputs, isTaprootChannel = false) assertTrue { remoteHtlcSigs.keys.containsAll(htlcTxs.map { it.tx.txid }) } htlcTxs.forEach { htlcTx -> val localHtlcSig = Transactions.sign(htlcTx, local_htlc_privkey, SigHash.SIGHASH_ALL) diff --git a/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt index eccfc4e4d..3cafcdea6 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/transactions/TransactionsTestsCommon.kt @@ -20,7 +20,9 @@ import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.transactions.CommitmentOutput.OutHtlc import fr.acinq.lightning.transactions.Scripts.htlcOffered import fr.acinq.lightning.transactions.Scripts.htlcReceived +import fr.acinq.lightning.transactions.Scripts.musig2Aggregate import fr.acinq.lightning.transactions.Scripts.toLocalDelayed +import fr.acinq.lightning.transactions.Transactions.NUMS_POINT import fr.acinq.lightning.transactions.Transactions.PlaceHolderSig import fr.acinq.lightning.transactions.Transactions.TxGenerationSkipped.AmountBelowDustLimit import fr.acinq.lightning.transactions.Transactions.TxGenerationSkipped.OutputNotFound @@ -46,6 +48,7 @@ import fr.acinq.lightning.transactions.Transactions.makeClaimRemoteDelayedOutput import fr.acinq.lightning.transactions.Transactions.makeClosingTx import fr.acinq.lightning.transactions.Transactions.makeCommitTx import fr.acinq.lightning.transactions.Transactions.makeCommitTxOutputs +import fr.acinq.lightning.transactions.Transactions.makeHtlcDelayedTx import fr.acinq.lightning.transactions.Transactions.makeHtlcPenaltyTx import fr.acinq.lightning.transactions.Transactions.makeHtlcTxs import fr.acinq.lightning.transactions.Transactions.makeMainPenaltyTx @@ -68,7 +71,7 @@ class TransactionsTestsCommon : LightningTestSuite() { private val remotePaymentPriv = PrivateKey(randomBytes32()) private val localHtlcPriv = PrivateKey(randomBytes32()) private val remoteHtlcPriv = PrivateKey(randomBytes32()) - private val commitInput = Funding.makeFundingInputInfo(TxId(randomBytes32()), 0, 1.btc, localFundingPriv.publicKey(), remoteFundingPriv.publicKey()) + private val commitInput = Funding.makeFundingInputInfo(TxId(randomBytes32()), 0, 1.btc, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), false) private val toLocalDelay = CltvExpiryDelta(144) private val localDustLimit = 546.sat private val feerate = FeeratePerKw(22_000.sat) @@ -105,7 +108,7 @@ class TransactionsTestsCommon : LightningTestSuite() { IncomingHtlc(UpdateAddHtlc(ByteVector32.Zeroes, 0, 800000.msat, ByteVector32.Zeroes, CltvExpiry(551), TestConstants.emptyOnionPacket)) ) val spec = CommitmentSpec(htlcs, feerate = FeeratePerKw(5_000.sat), toLocal = 0.msat, toRemote = 0.msat) - val fee = commitTxFee(546.sat, spec) + val fee = commitTxFee(546.sat, spec, isTaprootChannel = false) assertEquals(8000.sat, fee) } @@ -116,13 +119,14 @@ class TransactionsTestsCommon : LightningTestSuite() { val toLocalDelay = CltvExpiryDelta(144) val feeratePerKw = FeeratePerKw.MinimumFeeratePerKw val blockHeight = 400_000 + val isTaprootChannel = false run { // ClaimHtlcDelayedTx // first we create a fake htlcSuccessOrTimeoutTx tx, containing only the output that will be spent by the ClaimDelayedOutputTx val pubKeyScript = write(pay2wsh(toLocalDelayed(localRevocationPriv.publicKey(), toLocalDelay, localPaymentPriv.publicKey()))) val htlcSuccessOrTimeoutTx = Transaction(version = 0, txIn = listOf(TxIn(OutPoint(TxId(ByteVector32.Zeroes), 0), TxIn.SEQUENCE_FINAL)), txOut = listOf(TxOut(20000.sat, pubKeyScript)), lockTime = 0) - val claimHtlcDelayedTx = makeClaimLocalDelayedOutputTx(htlcSuccessOrTimeoutTx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localPaymentPriv.publicKey(), finalPubKeyScript, feeratePerKw) + val claimHtlcDelayedTx = makeClaimLocalDelayedOutputTx(htlcSuccessOrTimeoutTx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localPaymentPriv.publicKey(), finalPubKeyScript, feeratePerKw, isTaprootChannel) assertTrue(claimHtlcDelayedTx is Success, "is $claimHtlcDelayedTx") // we use dummy signatures to compute the weight val weight = Transaction.weight(addSigs(claimHtlcDelayedTx.result, PlaceHolderSig).tx) @@ -174,7 +178,8 @@ class TransactionsTestsCommon : LightningTestSuite() { remotePaymentPriv.publicKey(), localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), - spec + spec, + isTaprootChannel ) val commitTx = Transaction(version = 0, txIn = listOf(TxIn(OutPoint(TxId(ByteVector32.Zeroes), 0), TxIn.SEQUENCE_FINAL)), txOut = outputs.map { it.output }, lockTime = 0) val claimHtlcSuccessTx = @@ -203,7 +208,8 @@ class TransactionsTestsCommon : LightningTestSuite() { remotePaymentPriv.publicKey(), localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), - spec + spec, + isTaprootChannel ) val commitTx = Transaction(version = 0, txIn = listOf(TxIn(OutPoint(TxId(ByteVector32.Zeroes), 0), TxIn.SEQUENCE_FINAL)), txOut = outputs.map { it.output }, lockTime = 0) val claimHtlcTimeoutTx = @@ -216,10 +222,211 @@ class TransactionsTestsCommon : LightningTestSuite() { } } + @Test + fun `build taproot transactions`() { + + // funding tx sends to musig2 aggregate of local and remote funding keys + val fundingTxOutpoint = OutPoint(TxId(randomBytes32()), 0) + val fundingOutput = TxOut(Satoshi(100000), Script.pay2tr(musig2Aggregate(localFundingPriv.publicKey(), remoteFundingPriv.publicKey()), null as ByteVector32?)) + + // to-local output script tree, with 2 leaves + val toLocalScriptTree = ScriptTree.Branch( + ScriptTree.Leaf(0, Scripts.Taproot.toDelayScript(localDelayedPaymentPriv.publicKey(), toLocalDelay)), + ScriptTree.Leaf(1, Scripts.Taproot.toRevokeScript(localRevocationPriv.publicKey(), localDelayedPaymentPriv.publicKey())), + ) + + // to-remote output script tree, with a single leaf + val toRemoteScriptTree = ScriptTree.Leaf(0, Scripts.Taproot.toRemoteScript(remotePaymentPriv.publicKey())) + + // offered HTLC + val preimage = ByteVector32.fromValidHex("0101010101010101010101010101010101010101010101010101010101010101") + val paymentHash = sha256(preimage).byteVector32() + val offeredHtlcTree = Scripts.Taproot.offeredHtlcTree(localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), paymentHash) + val receivedHtlcTree = Scripts.Taproot.receivedHtlcTree(localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), paymentHash, CltvExpiry(300)) + + val txNumber = 0x404142434445L + val (sequence, lockTime) = encodeTxNumber(txNumber) + val commitTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(fundingTxOutpoint, sequence)), + txOut = listOf( + TxOut(30000000.sat, Script.pay2tr(XonlyPublicKey(NUMS_POINT), toLocalScriptTree)), + TxOut(40000000.sat, Script.pay2tr(XonlyPublicKey(NUMS_POINT), toRemoteScriptTree)), + TxOut(330.sat, Script.pay2tr(localDelayedPaymentPriv.xOnlyPublicKey(), Scripts.Taproot.anchorScriptTree)), + TxOut(330.sat, Script.pay2tr(remotePaymentPriv.xOnlyPublicKey(), Scripts.Taproot.anchorScriptTree)), + TxOut(100.sat, Script.pay2tr(localRevocationPriv.xOnlyPublicKey(), offeredHtlcTree)), + TxOut(150.sat, Script.pay2tr(localRevocationPriv.xOnlyPublicKey(), receivedHtlcTree)) + ), + lockTime + ) + + val localNonce = Musig2.generateNonce(randomBytes32(), localFundingPriv, listOf(localFundingPriv.publicKey())) + val remoteNonce = Musig2.generateNonce(randomBytes32(), remoteFundingPriv, listOf(remoteFundingPriv.publicKey())) + + val localPartialSig = Musig2.signTaprootInput( + localFundingPriv, + tx, 0, listOf(fundingOutput), + Scripts.sort(listOf(localFundingPriv.publicKey(), remoteFundingPriv.publicKey())), + localNonce.first, listOf(localNonce.second, remoteNonce.second), + null + ).right!! + + val remotePartialSig = Musig2.signTaprootInput( + remoteFundingPriv, + tx, 0, listOf(fundingOutput), + Scripts.sort(listOf(localFundingPriv.publicKey(), remoteFundingPriv.publicKey())), + remoteNonce.first, listOf(localNonce.second, remoteNonce.second), + null + ).right!! + + val aggSig = Musig2.aggregateTaprootSignatures( + listOf(localPartialSig, remotePartialSig), tx, 0, + listOf(fundingOutput), + Scripts.sort(listOf(localFundingPriv.publicKey(), remoteFundingPriv.publicKey())), + listOf(localNonce.second, remoteNonce.second), + null + ).right!! + + tx.updateWitness(0, Script.witnessKeyPathPay2tr(aggSig)) + } + Transaction.correctlySpends(commitTx, mapOf(fundingTxOutpoint to fundingOutput), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + val finalPubKeyScript = Script.write(Script.pay2wpkh(PrivateKey(randomBytes32()).publicKey())) + + val spendToLocalOutputTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(commitTx, 0), sequence = toLocalDelay.toLong())), + txOut = listOf(TxOut(30000000.sat, finalPubKeyScript)), + lockTime = 0 + ) + val sig = Transaction.signInputTaprootScriptPath(localDelayedPaymentPriv, tx, 0, listOf(commitTx.txOut[0]), SigHash.SIGHASH_DEFAULT, toLocalScriptTree.left.hash()) + val witness = Script.witnessScriptPathPay2tr(XonlyPublicKey(NUMS_POINT), toLocalScriptTree.left as ScriptTree.Leaf, ScriptWitness(listOf(sig)), toLocalScriptTree) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(spendToLocalOutputTx, listOf(commitTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + + val spendToRemoteOutputTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(commitTx, 1), sequence = 1)), + txOut = listOf(TxOut(40000000.sat, finalPubKeyScript)), + lockTime = 0 + ) + val sig = Transaction.signInputTaprootScriptPath(remotePaymentPriv, tx, 0, listOf(commitTx.txOut[1]), SigHash.SIGHASH_DEFAULT, toRemoteScriptTree.hash()) + val witness = Script.witnessScriptPathPay2tr(XonlyPublicKey(NUMS_POINT), toRemoteScriptTree, ScriptWitness(listOf(sig)), toRemoteScriptTree) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(spendToRemoteOutputTx, listOf(commitTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + val spendLocalAnchorTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(commitTx, 2), sequence = TxIn.SEQUENCE_FINAL)), + txOut = listOf(TxOut(330.sat, finalPubKeyScript)), + lockTime = 0 + ) + val sig = Transaction.signInputTaprootKeyPath(localDelayedPaymentPriv, tx, 0, listOf(commitTx.txOut[2]), SigHash.SIGHASH_DEFAULT, Scripts.Taproot.anchorScriptTree) + val witness = Script.witnessKeyPathPay2tr(sig) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(spendLocalAnchorTx, listOf(commitTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + val spendRemoteAnchorTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(commitTx, 3), listOf(), sequence = TxIn.SEQUENCE_FINAL)), + txOut = listOf(TxOut(330.sat, finalPubKeyScript)), + lockTime = 0 + ) + val sig = Transaction.signInputTaprootKeyPath(remotePaymentPriv, tx, 0, listOf(commitTx.txOut[3]), SigHash.SIGHASH_DEFAULT, Scripts.Taproot.anchorScriptTree) + val witness = Script.witnessKeyPathPay2tr(sig) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(spendRemoteAnchorTx, listOf(commitTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + val mainPenaltyTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(commitTx, 0), sequence = TxIn.SEQUENCE_FINAL)), + txOut = listOf(TxOut(330.sat, finalPubKeyScript)), + lockTime = 0 + ) + val sig = Transaction.signInputTaprootScriptPath(localRevocationPriv, tx, 0, listOf(commitTx.txOut[0]), SigHash.SIGHASH_DEFAULT, toLocalScriptTree.right.hash()) + val witness = Script.witnessScriptPathPay2tr(XonlyPublicKey(NUMS_POINT), toLocalScriptTree.right as ScriptTree.Leaf, ScriptWitness(listOf(sig)), toLocalScriptTree) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(mainPenaltyTx, listOf(commitTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + // sign and spend received HTLC with HTLC-Success tx + val htlcSuccessTree = ScriptTree.Leaf(0, Scripts.Taproot.toDelayScript(localDelayedPaymentPriv.publicKey(), toLocalDelay)) + val htlcSuccessTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(commitTx, 5), sequence = 1)), + txOut = listOf(TxOut(150.sat, Script.pay2tr(localRevocationPriv.xOnlyPublicKey(), htlcSuccessTree))), + lockTime = 0 + ) + val sigHash = SigHash.SIGHASH_SINGLE or SigHash.SIGHASH_ANYONECANPAY + val localSig = Transaction.signInputTaprootScriptPath(localHtlcPriv, tx, 0, listOf(commitTx.txOut[5]), sigHash, receivedHtlcTree.right.hash()).toByteArray() + sigHash.toByte() + val remoteSig = Transaction.signInputTaprootScriptPath(remoteHtlcPriv, tx, 0, listOf(commitTx.txOut[5]), sigHash, receivedHtlcTree.right.hash()).toByteArray() + sigHash.toByte() + val witness = Script.witnessScriptPathPay2tr(localRevocationPriv.xOnlyPublicKey(), receivedHtlcTree.right as ScriptTree.Leaf, ScriptWitness(listOf(remoteSig.byteVector(), localSig.byteVector(), preimage)), receivedHtlcTree) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(htlcSuccessTx, listOf(commitTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + val spendHtlcSuccessTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(htlcSuccessTx, 0), sequence = toLocalDelay.toLong())), + txOut = listOf(TxOut(150.sat, finalPubKeyScript)), + lockTime = 0 + ) + val localSig = Transaction.signInputTaprootScriptPath(localDelayedPaymentPriv, tx, 0, listOf(htlcSuccessTx.txOut[0]), SigHash.SIGHASH_DEFAULT, htlcSuccessTree.hash()) + val witness = Script.witnessScriptPathPay2tr(localRevocationPriv.xOnlyPublicKey(), htlcSuccessTree, ScriptWitness(listOf(localSig)), htlcSuccessTree) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(spendHtlcSuccessTx, listOf(htlcSuccessTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + // sign and spend offered HTLC with HTLC-Timeout tx + val htlcTimeoutTree = htlcSuccessTree + val htlcTimeoutTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(commitTx, 4), sequence = TxIn.SEQUENCE_FINAL)), + txOut = listOf(TxOut(100.sat, Script.pay2tr(localRevocationPriv.xOnlyPublicKey(), htlcTimeoutTree))), + lockTime = CltvExpiry(300).toLong() + ) + val sigHash = SigHash.SIGHASH_SINGLE or SigHash.SIGHASH_ANYONECANPAY + val localSig = Transaction.signInputTaprootScriptPath(localHtlcPriv, tx, 0, listOf(commitTx.txOut[4]), sigHash, offeredHtlcTree.left.hash()).toByteArray() + sigHash.toByte() + val remoteSig = Transaction.signInputTaprootScriptPath(remoteHtlcPriv, tx, 0, listOf(commitTx.txOut[4]), sigHash, offeredHtlcTree.left.hash()).toByteArray() + sigHash.toByte() + val witness = Script.witnessScriptPathPay2tr(localRevocationPriv.xOnlyPublicKey(), offeredHtlcTree.left as ScriptTree.Leaf, ScriptWitness(listOf(remoteSig.byteVector(), localSig.byteVector())), offeredHtlcTree) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(htlcTimeoutTx, listOf(commitTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + val spendHtlcTimeoutTx = run { + val tx = Transaction( + version = 2, + txIn = listOf(TxIn(OutPoint(htlcTimeoutTx, 0), sequence = toLocalDelay.toLong())), + txOut = listOf(TxOut(100.sat, finalPubKeyScript)), + lockTime = 0 + ) + val localSig = Transaction.signInputTaprootScriptPath(localDelayedPaymentPriv, tx, 0, listOf(htlcTimeoutTx.txOut[0]), SigHash.SIGHASH_DEFAULT, htlcTimeoutTree.hash()) + val witness = Script.witnessScriptPathPay2tr(localRevocationPriv.xOnlyPublicKey(), htlcTimeoutTree, ScriptWitness(listOf(localSig)), htlcTimeoutTree) + tx.updateWitness(0, witness) + } + Transaction.correctlySpends(spendHtlcTimeoutTx, listOf(htlcTimeoutTx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) + + } + @Test fun `generate valid commitment and htlc transactions`() { + val isTaprootChannel = false val finalPubKeyScript = write(pay2wpkh(PrivateKey(ByteVector32("01".repeat(32))).publicKey())) - val commitInput = Funding.makeFundingInputInfo(TxId(ByteVector32("02".repeat(32))), 0, 1.btc, localFundingPriv.publicKey(), remoteFundingPriv.publicKey()) + val commitInput = Funding.makeFundingInputInfo(TxId(ByteVector32("02".repeat(32))), 0, 1.btc, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), isTaprootChannel) // htlc1 and htlc2 are regular IN/OUT htlcs val paymentPreimage1 = ByteVector32("03".repeat(32)) @@ -268,7 +475,8 @@ class TransactionsTestsCommon : LightningTestSuite() { remotePaymentPriv.publicKey(), localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), - spec + spec, + isTaprootChannel ) val commitTxNumber = 0x404142434445L @@ -286,7 +494,7 @@ class TransactionsTestsCommon : LightningTestSuite() { val check = ((commitTx.tx.txIn.first().sequence and 0xffffffL) shl 24) or (commitTx.tx.lockTime and 0xffffffL) assertEquals(commitTxNumber, check xor num) } - val htlcTxs = makeHtlcTxs(commitTx.tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), spec.feerate, outputs) + val htlcTxs = makeHtlcTxs(commitTx.tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), spec.feerate, outputs, isTaprootChannel) assertEquals(4, htlcTxs.size) val htlcSuccessTxs = htlcTxs.filterIsInstance() assertEquals(2, htlcSuccessTxs.size) // htlc2 and htlc4 @@ -307,13 +515,252 @@ class TransactionsTestsCommon : LightningTestSuite() { } run { // local spends delayed output of htlc1 timeout tx - val claimHtlcDelayed = makeClaimLocalDelayedOutputTx(htlcTimeoutTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + val claimHtlcDelayed = makeClaimLocalDelayedOutputTx(htlcTimeoutTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) + assertTrue(claimHtlcDelayed is Success, "is $claimHtlcDelayed") + val localSig = sign(claimHtlcDelayed.result, localDelayedPaymentPriv) + val signedTx = addSigs(claimHtlcDelayed.result, localSig) + assertTrue(checkSpendable(signedTx).isSuccess) + // local can't claim delayed output of htlc3 timeout tx because it is below the dust limit + val claimHtlcDelayed1 = makeClaimLocalDelayedOutputTx(htlcTimeoutTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) + assertEquals(Skipped(OutputNotFound), claimHtlcDelayed1) + } + run { + // remote spends local->remote htlc1/htlc3 output directly in case of success + for ((htlc, paymentPreimage) in listOf(htlc1 to paymentPreimage1, htlc3 to paymentPreimage3)) { + val claimHtlcSuccessTx = + makeClaimHtlcSuccessTx(commitTx.tx, outputs, localDustLimit, remoteHtlcPriv.publicKey(), localHtlcPriv.publicKey(), localRevocationPriv.publicKey(), finalPubKeyScript, htlc, feerate) + assertTrue(claimHtlcSuccessTx is Success, "is $claimHtlcSuccessTx") + val localSig = sign(claimHtlcSuccessTx.result, remoteHtlcPriv) + val signed = addSigs(claimHtlcSuccessTx.result, localSig, paymentPreimage) + val csResult = checkSpendable(signed) + assertTrue(csResult.isSuccess, "is $csResult") + } + } + run { + // local spends remote->local htlc2/htlc4 output with htlc success tx using payment preimage + for ((htlcSuccessTx, paymentPreimage) in listOf(htlcSuccessTxs[1] to paymentPreimage2, htlcSuccessTxs[0] to paymentPreimage4)) { + val localSig = sign(htlcSuccessTx, localHtlcPriv) + val remoteSig = sign(htlcSuccessTx, remoteHtlcPriv, SigHash.SIGHASH_SINGLE or SigHash.SIGHASH_ANYONECANPAY) + val signedTx = addSigs(htlcSuccessTx, localSig, remoteSig, paymentPreimage) + val csResult = checkSpendable(signedTx) + assertTrue(csResult.isSuccess, "is $csResult") + // check remote sig + assertTrue(checkSig(htlcSuccessTx, remoteSig, remoteHtlcPriv.publicKey(), SigHash.SIGHASH_SINGLE or SigHash.SIGHASH_ANYONECANPAY)) + } + } + run { + // local spends delayed output of htlc2 success tx + val claimHtlcDelayed = makeClaimLocalDelayedOutputTx(htlcSuccessTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) assertTrue(claimHtlcDelayed is Success, "is $claimHtlcDelayed") val localSig = sign(claimHtlcDelayed.result, localDelayedPaymentPriv) val signedTx = addSigs(claimHtlcDelayed.result, localSig) + val csResult = checkSpendable(signedTx) + assertTrue(csResult.isSuccess, "is $csResult") + // local can't claim delayed output of htlc4 timeout tx because it is below the dust limit + val claimHtlcDelayed1 = makeClaimLocalDelayedOutputTx(htlcSuccessTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) + assertEquals(Skipped(AmountBelowDustLimit), claimHtlcDelayed1) + } + run { + // remote spends main output + val claimP2WPKHOutputTx = makeClaimRemoteDelayedOutputTx(commitTx.tx, localDustLimit, remotePaymentPriv.publicKey(), finalPubKeyScript.toByteVector(), feerate) + assertTrue(claimP2WPKHOutputTx is Success, "is $claimP2WPKHOutputTx") + val localSig = sign(claimP2WPKHOutputTx.result, remotePaymentPriv) + val signedTx = addSigs(claimP2WPKHOutputTx.result, localSig) + val csResult = checkSpendable(signedTx) + assertTrue(csResult.isSuccess, "is $csResult") + } + run { + // remote spends htlc1's htlc-timeout tx with revocation key + val claimHtlcDelayedPenaltyTxs = makeClaimDelayedOutputPenaltyTxs(htlcTimeoutTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + assertEquals(1, claimHtlcDelayedPenaltyTxs.size) + val claimHtlcDelayedPenaltyTx = claimHtlcDelayedPenaltyTxs.first() + assertTrue(claimHtlcDelayedPenaltyTx is Success, "is $claimHtlcDelayedPenaltyTx") + val sig = sign(claimHtlcDelayedPenaltyTx.result, localRevocationPriv) + val signed = addSigs(claimHtlcDelayedPenaltyTx.result, sig) + val csResult = checkSpendable(signed) + assertTrue(csResult.isSuccess, "is $csResult") + // remote can't claim revoked output of htlc3's htlc-timeout tx because it is below the dust limit + val claimHtlcDelayedPenaltyTxsSkipped = makeClaimDelayedOutputPenaltyTxs(htlcTimeoutTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + assertEquals(listOf(Skipped(AmountBelowDustLimit)), claimHtlcDelayedPenaltyTxsSkipped) + } + run { + // remote spends remote->local htlc output directly in case of timeout + val claimHtlcTimeoutTx = + makeClaimHtlcTimeoutTx(commitTx.tx, outputs, localDustLimit, remoteHtlcPriv.publicKey(), localHtlcPriv.publicKey(), localRevocationPriv.publicKey(), finalPubKeyScript, htlc2, feerate) + assertTrue(claimHtlcTimeoutTx is Success, "is $claimHtlcTimeoutTx") + val remoteSig = sign(claimHtlcTimeoutTx.result, remoteHtlcPriv) + val signed = addSigs(claimHtlcTimeoutTx.result, remoteSig) + val csResult = checkSpendable(signed) + assertTrue(csResult.isSuccess, "is $csResult") + } + run { + // remote spends htlc2's htlc-success tx with revocation key + val claimHtlcDelayedPenaltyTxs = makeClaimDelayedOutputPenaltyTxs(htlcSuccessTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + assertEquals(1, claimHtlcDelayedPenaltyTxs.size) + val claimHtlcDelayedPenaltyTx = claimHtlcDelayedPenaltyTxs.first() + assertTrue(claimHtlcDelayedPenaltyTx is Success, "is $claimHtlcDelayedPenaltyTx") + val sig = sign(claimHtlcDelayedPenaltyTx.result, localRevocationPriv) + val signed = addSigs(claimHtlcDelayedPenaltyTx.result, sig) + val csResult = checkSpendable(signed) + assertTrue(csResult.isSuccess, "is $csResult") + // remote can't claim revoked output of htlc4's htlc-success tx because it is below the dust limit + val claimHtlcDelayedPenaltyTxsSkipped = makeClaimDelayedOutputPenaltyTxs(htlcSuccessTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + assertEquals(listOf(Skipped(AmountBelowDustLimit)), claimHtlcDelayedPenaltyTxsSkipped) + } + run { + // remote spends all htlc txs aggregated in a single tx + val txIn = htlcTimeoutTxs.flatMap { it.tx.txIn } + htlcSuccessTxs.flatMap { it.tx.txIn } + val txOut = htlcTimeoutTxs.flatMap { it.tx.txOut } + htlcSuccessTxs.flatMap { it.tx.txOut } + val aggregatedHtlcTx = Transaction(2, txIn, txOut, 0) + val claimHtlcDelayedPenaltyTxs = makeClaimDelayedOutputPenaltyTxs(aggregatedHtlcTx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + assertEquals(4, claimHtlcDelayedPenaltyTxs.size) + val skipped = claimHtlcDelayedPenaltyTxs.filterIsInstance>() + assertEquals(2, skipped.size) + val claimed = claimHtlcDelayedPenaltyTxs.filterIsInstance>() + assertEquals(2, claimed.size) + assertEquals(2, claimed.map { it.result.input.outPoint }.toSet().size) + } + run { + // remote spends offered HTLC output with revocation key + val script = write(htlcOffered(localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), localRevocationPriv.publicKey(), ripemd160(htlc1.paymentHash))) + val htlcOutputIndex = outputs.indexOfFirst { + val outHtlc = (it.commitmentOutput as? OutHtlc)?.outgoingHtlc?.add + outHtlc != null && outHtlc.id == htlc1.id + } + val htlcPenaltyTx = makeHtlcPenaltyTx(commitTx.tx, htlcOutputIndex, script, localDustLimit, finalPubKeyScript, feerate) + assertTrue(htlcPenaltyTx is Success, "is $htlcPenaltyTx") + val sig = sign(htlcPenaltyTx.result, localRevocationPriv) + val signed = addSigs(htlcPenaltyTx.result, sig, localRevocationPriv.publicKey()) + val csResult = checkSpendable(signed) + assertTrue(csResult.isSuccess, "is $csResult") + } + run { + // remote spends received HTLC output with revocation key + val script = write(htlcReceived(localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), localRevocationPriv.publicKey(), ripemd160(htlc2.paymentHash), htlc2.cltvExpiry)) + val htlcOutputIndex = outputs.indexOfFirst { + val inHtlc = (it.commitmentOutput as? CommitmentOutput.InHtlc)?.incomingHtlc?.add + inHtlc != null && inHtlc.id == htlc2.id + } + val htlcPenaltyTx = makeHtlcPenaltyTx(commitTx.tx, htlcOutputIndex, script, localDustLimit, finalPubKeyScript, feerate) + assertTrue(htlcPenaltyTx is Success, "is $htlcPenaltyTx") + val sig = sign(htlcPenaltyTx.result, localRevocationPriv) + val signed = addSigs(htlcPenaltyTx.result, sig, localRevocationPriv.publicKey()) + val csResult = checkSpendable(signed) + assertTrue(csResult.isSuccess, "is $csResult") + } + } + + @Test + fun `generate valid commitment and htlc transactions -- simple taproot channels`() { + val isTaprootChannel = true + val finalPubKeyScript = write(pay2wpkh(PrivateKey(ByteVector32("01".repeat(32))).publicKey())) + val commitInput = Funding.makeFundingInputInfo(TxId(ByteVector32("02".repeat(32))), 0, 1.btc, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), isTaprootChannel) + + // htlc1 and htlc2 are regular IN/OUT htlcs + val paymentPreimage1 = ByteVector32("03".repeat(32)) + val htlc1 = UpdateAddHtlc(ByteVector32.Zeroes, 0, 100.mbtc.toMilliSatoshi(), ByteVector32(sha256(paymentPreimage1)), CltvExpiry(300), TestConstants.emptyOnionPacket) + val paymentPreimage2 = ByteVector32("04".repeat(32)) + val htlc2 = UpdateAddHtlc(ByteVector32.Zeroes, 1, 200.mbtc.toMilliSatoshi(), ByteVector32(sha256(paymentPreimage2)), CltvExpiry(300), TestConstants.emptyOnionPacket) + // htlc3 and htlc4 are dust htlcs IN/OUT htlcs, with an amount large enough to be included in the commit tx, but too small to be claimed at 2nd stage + val paymentPreimage3 = ByteVector32("05".repeat(32)) + val htlc3 = UpdateAddHtlc( + ByteVector32.Zeroes, + 2, + (localDustLimit + weight2fee(feerate, Commitments.HTLC_TIMEOUT_WEIGHT)).toMilliSatoshi(), + ByteVector32(sha256(paymentPreimage3)), + CltvExpiry(300), + TestConstants.emptyOnionPacket + ) + val paymentPreimage4 = ByteVector32("06".repeat(32)) + val htlc4 = UpdateAddHtlc( + ByteVector32.Zeroes, + 3, + (localDustLimit + weight2fee(feerate, Commitments.HTLC_SUCCESS_WEIGHT)).toMilliSatoshi(), + ByteVector32(sha256(paymentPreimage4)), + CltvExpiry(300), + TestConstants.emptyOnionPacket + ) + val spec = CommitmentSpec( + htlcs = setOf( + OutgoingHtlc(htlc1), + IncomingHtlc(htlc2), + OutgoingHtlc(htlc3), + IncomingHtlc(htlc4) + ), + feerate = feerate, + toLocal = 400.mbtc.toMilliSatoshi(), + toRemote = 300.mbtc.toMilliSatoshi() + ) + + val outputs = makeCommitTxOutputs( + localFundingPriv.publicKey(), + remoteFundingPriv.publicKey(), + true, + localDustLimit, + localRevocationPriv.publicKey(), + toLocalDelay, + localDelayedPaymentPriv.publicKey(), + remotePaymentPriv.publicKey(), + localHtlcPriv.publicKey(), + remoteHtlcPriv.publicKey(), + spec, + isTaprootChannel + ) + val localNonce = Musig2.generateNonce(randomBytes32(), localFundingPriv, listOf(localFundingPriv.publicKey())) + val remoteNonce = Musig2.generateNonce(randomBytes32(), remoteFundingPriv, listOf(remoteFundingPriv.publicKey())) + val commitTxNumber = 0x404142434445L + val commitTx = run { + val txInfo = makeCommitTx(commitInput, commitTxNumber, localPaymentPriv.publicKey(), remotePaymentPriv.publicKey(), true, outputs) + when (isTaprootChannel) { + true -> { + val localSig = Transactions.partialSign(txInfo, localFundingPriv, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), localNonce, remoteNonce.second).right!! + val remoteSig = Transactions.partialSign(txInfo, remoteFundingPriv, remoteFundingPriv.publicKey(), localFundingPriv.publicKey(), remoteNonce, localNonce.second).right!! + val aggSig = Transactions.aggregatePartialSignatures(txInfo, localSig, remoteSig, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), localNonce.second, remoteNonce.second).right!! + Transactions.addAggregatedSignature(txInfo, aggSig) + } + else -> { + val localSig = sign(txInfo, localPaymentPriv) + val remoteSig = sign(txInfo, remotePaymentPriv) + addSigs(txInfo, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), localSig, remoteSig) + } + } + } + + run { + assertEquals(commitTxNumber, getCommitTxNumber(commitTx.tx, true, localPaymentPriv.publicKey(), remotePaymentPriv.publicKey())) + val hash = sha256(localPaymentPriv.publicKey().value + remotePaymentPriv.publicKey().value) + val num = Pack.int64BE(hash.takeLast(8).toByteArray()) and 0xffffffffffffL + val check = ((commitTx.tx.txIn.first().sequence and 0xffffffL) shl 24) or (commitTx.tx.lockTime and 0xffffffL) + assertEquals(commitTxNumber, check xor num) + } + val htlcTxs = makeHtlcTxs(commitTx.tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), spec.feerate, outputs, isTaprootChannel) + assertEquals(4, htlcTxs.size) + val htlcSuccessTxs = htlcTxs.filterIsInstance() + assertEquals(2, htlcSuccessTxs.size) // htlc2 and htlc4 + assertEquals(setOf(1L, 3L), htlcSuccessTxs.map { it.htlcId }.toSet()) + val htlcTimeoutTxs = htlcTxs.filterIsInstance() + assertEquals(2, htlcTimeoutTxs.size) // htlc1 and htlc3 + assertEquals(setOf(0L, 2L), htlcTimeoutTxs.map { it.htlcId }.toSet()) + + run { + // either party spends local->remote htlc output with htlc timeout tx + for (htlcTimeoutTx in htlcTimeoutTxs) { + val localSig = htlcTimeoutTx.sign(localHtlcPriv) + val remoteSig = htlcTimeoutTx.sign(remoteHtlcPriv, SigHash.SIGHASH_SINGLE or SigHash.SIGHASH_ANYONECANPAY) + val signed = addSigs(htlcTimeoutTx, localSig, remoteSig) + val csResult = checkSpendable(signed) + assertTrue(csResult.isSuccess, "is $csResult") + } + } + run { + // local spends delayed output of htlc1 timeout tx + val claimHtlcDelayed = makeHtlcDelayedTx(htlcTimeoutTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) + assertTrue(claimHtlcDelayed is Success, "is $claimHtlcDelayed") + val localSig = claimHtlcDelayed.result.sign(localDelayedPaymentPriv) + val signedTx = addSigs(claimHtlcDelayed.result, localSig) assertTrue(checkSpendable(signedTx).isSuccess) // local can't claim delayed output of htlc3 timeout tx because it is below the dust limit - val claimHtlcDelayed1 = makeClaimLocalDelayedOutputTx(htlcTimeoutTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localPaymentPriv.publicKey(), finalPubKeyScript, feerate) + val claimHtlcDelayed1 = makeHtlcDelayedTx(htlcTimeoutTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) assertEquals(Skipped(OutputNotFound), claimHtlcDelayed1) } run { @@ -342,14 +789,14 @@ class TransactionsTestsCommon : LightningTestSuite() { } run { // local spends delayed output of htlc2 success tx - val claimHtlcDelayed = makeClaimLocalDelayedOutputTx(htlcSuccessTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + val claimHtlcDelayed = makeClaimLocalDelayedOutputTx(htlcSuccessTxs[1].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) assertTrue(claimHtlcDelayed is Success, "is $claimHtlcDelayed") val localSig = sign(claimHtlcDelayed.result, localDelayedPaymentPriv) val signedTx = addSigs(claimHtlcDelayed.result, localSig) val csResult = checkSpendable(signedTx) assertTrue(csResult.isSuccess, "is $csResult") // local can't claim delayed output of htlc4 timeout tx because it is below the dust limit - val claimHtlcDelayed1 = makeClaimLocalDelayedOutputTx(htlcSuccessTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate) + val claimHtlcDelayed1 = makeClaimLocalDelayedOutputTx(htlcSuccessTxs[0].tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), finalPubKeyScript, feerate, isTaprootChannel) assertEquals(Skipped(AmountBelowDustLimit), claimHtlcDelayed1) } run { @@ -574,7 +1021,7 @@ class TransactionsTestsCommon : LightningTestSuite() { val remotePaymentPriv = PrivateKey.fromHex("a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6") val localHtlcPriv = PrivateKey.fromHex("a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7a7") val remoteHtlcPriv = PrivateKey.fromHex("a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8") - val commitInput = Funding.makeFundingInputInfo(TxId("a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0"), 0, 1.btc, localFundingPriv.publicKey(), remoteFundingPriv.publicKey()) + val commitInput = Funding.makeFundingInputInfo(TxId("a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0"), 0, 1.btc, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), isTaprootChannel = false) // htlc1 and htlc2 are two regular incoming HTLCs with different amounts. // htlc2 and htlc3 have the same amounts and should be sorted according to their scriptPubKey @@ -602,6 +1049,7 @@ class TransactionsTestsCommon : LightningTestSuite() { ) val commitTxNumber = 0x404142434446L + val isTaprootChannel = false val (commitTx, outputs, htlcTxs) = run { val outputs = makeCommitTxOutputs( @@ -615,13 +1063,14 @@ class TransactionsTestsCommon : LightningTestSuite() { remotePaymentPriv.publicKey(), localHtlcPriv.publicKey(), remoteHtlcPriv.publicKey(), - spec + spec, + isTaprootChannel ) val txInfo = makeCommitTx(commitInput, commitTxNumber, localPaymentPriv.publicKey(), remotePaymentPriv.publicKey(), true, outputs) val localSig = sign(txInfo, localPaymentPriv) val remoteSig = sign(txInfo, remotePaymentPriv) val commitTx = addSigs(txInfo, localFundingPriv.publicKey(), remoteFundingPriv.publicKey(), localSig, remoteSig) - val htlcTxs = makeHtlcTxs(commitTx.tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), feerate, outputs) + val htlcTxs = makeHtlcTxs(commitTx.tx, localDustLimit, localRevocationPriv.publicKey(), toLocalDelay, localDelayedPaymentPriv.publicKey(), feerate, outputs, isTaprootChannel) Triple(commitTx, outputs, htlcTxs) }