Skip to content

Commit

Permalink
Rework TxComplete to use implicit ordering for musig2 nonces
Browse files Browse the repository at this point in the history
Instead of sending an explicit serialId -> nonce map, we send a list of public nonces ordered by serial id.
This matches how signatures are sent in TxSignatures.
  • Loading branch information
sstone committed Nov 27, 2023
1 parent ce75299 commit 504c49d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 31 deletions.
47 changes: 31 additions & 16 deletions src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,11 @@ data class FundingContributions(val inputs: List<InteractiveTxInput.Outgoing>, v
)

fun weight(walletInputs: List<WalletState.Utxo>): Int = walletInputs.sumOf {
when {
Script.isPay2wsh(it.previousTx.txOut[it.outputIndex].publicKeyScript.toByteArray()) -> Transactions.swapInputWeight
else -> Transactions.swapInputWeightMusig2
}
when {
Script.isPay2wsh(it.previousTx.txOut[it.outputIndex].publicKeyScript.toByteArray()) -> Transactions.swapInputWeight
else -> Transactions.swapInputWeightMusig2
}
}

/** We always randomize the order of inputs and outputs. */
private fun sortFundingContributions(params: InteractiveTxParams, inputs: List<InteractiveTxInput.Outgoing>, outputs: List<InteractiveTxOutput.Outgoing>): FundingContributions {
Expand All @@ -392,7 +392,7 @@ data class FundingContributions(val inputs: List<InteractiveTxInput.Outgoing>, v
when (input) {
is InteractiveTxInput.LocalOnly -> input.copy(serialId = serialId)
is InteractiveTxInput.LocalSwapIn -> input.copy(serialId = serialId)
is InteractiveTxInput.LocalMusig2SwapIn-> input.copy(serialId = serialId)
is InteractiveTxInput.LocalMusig2SwapIn -> input.copy(serialId = serialId)
is InteractiveTxInput.Shared -> input.copy(serialId = serialId)
}
}
Expand Down Expand Up @@ -462,6 +462,16 @@ data class SharedTransaction(
val previousOutputsMap = sharedOutput + localOutputs + remoteOutputs
val previousOutputs = unsignedTx.txIn.map { previousOutputsMap[it.outPoint]!! }.toList()

// nonces that we've received for all musig2 swap-in
val receivedNonces: Map<Long, PublicNonce> = when (session.txCompleteReceived) {
null -> mapOf()
else -> (localInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>() + remoteInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>())
.sortedBy { it.serialId }
.zip(session.txCompleteReceived.publicNonces)
.associate { it.first.serialId to it.second }
}


// If we are swapping funds in, we provide our partial signatures to the corresponding inputs.
val swapUserSigs = unsignedTx.txIn.mapIndexed { i, txIn ->
localInputs
Expand All @@ -477,8 +487,8 @@ data class SharedTransaction(
?.let { input ->
val userNonce = input.secretNonce
require(session.txCompleteReceived != null)
val serverNonce = session.txCompleteReceived.publicNonces[input.serialId]
require(serverNonce != null)
val serverNonce = receivedNonces[input.serialId]
require(serverNonce != null) { "missing server nonce for input ${input.serialId}" }
val commonNonce = PublicNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce))
TxSignatures.Companion.PartialSignature(keyManager.swapInOnChainWallet.signSwapInputUserMusig2(unsignedTx, i, previousOutputs, userNonce, serverNonce), commonNonce)
}
Expand All @@ -504,8 +514,8 @@ data class SharedTransaction(
val serverKey = keyManager.swapInOnChainWallet.localServerPrivateKey(remoteNodeId)
val userNonce = input.secretNonce
require(session.txCompleteReceived != null)
val serverNonce = session.txCompleteReceived.publicNonces[input.serialId]
require(serverNonce != null)
val serverNonce = receivedNonces[input.serialId]
require(serverNonce != null) { "missing server nonce for input ${input.serialId}" }
val commonNonce = PublicNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce))
val swapInProtocol = SwapInProtocolMusig2(input.swapInParams.userKey, serverKey.publicKey(), input.swapInParams.userRefundKey, input.swapInParams.refundDelay)
TxSignatures.Companion.PartialSignature(swapInProtocol.signSwapInputServer(unsignedTx, i, previousOutputs, serverNonce, serverKey, userNonce), commonNonce)
Expand Down Expand Up @@ -568,7 +578,7 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over
val localSwapTxInMusig2 = tx.localInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>().sortedBy { i -> i.serialId }.zip(localSigs.swapInUserPartialSigs.zip(remoteSigs.swapInServerPartialSigs)).map { (i, sigs) ->
val (userSig, serverSig) = sigs
val swapInProtocol = SwapInProtocolMusig2(i.swapInParams)
require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce){ "aggregated public nonces mismatch for local input ${i.serialId}"}
require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce) { "aggregated public nonces mismatch for local input ${i.serialId}" }
val commonNonce = userSig.aggregatedPublicNonce
val unsignedTx = tx.buildUnsignedTx()
val ctx = swapInProtocol.signingCtx(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce)
Expand All @@ -587,7 +597,7 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over
val remoteSwapTxInMusig2 = tx.remoteInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>().sortedBy { i -> i.serialId }.zip(remoteSigs.swapInUserPartialSigs.zip(localSigs.swapInServerPartialSigs)).map { (i, sigs) ->
val (userSig, serverSig) = sigs
val swapInProtocol = SwapInProtocolMusig2(i.swapInParams)
require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce){ "aggregated public nonces mismatch for remote input ${i.serialId}"}
require(userSig.aggregatedPublicNonce == serverSig.aggregatedPublicNonce) { "aggregated public nonces mismatch for remote input ${i.serialId}" }
val commonNonce = userSig.aggregatedPublicNonce
val unsignedTx = tx.buildUnsignedTx()
val ctx = swapInProtocol.signingCtx(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce)
Expand Down Expand Up @@ -689,10 +699,10 @@ data class InteractiveTxSession(
null -> {
// generate a new secret nonce for each musig2 new swapin every time we send TxComplete
val localMusig2SwapIns = localInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>()
val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap()
val localNonces = localMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }
val remoteMusig2SwapIns = remoteInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>()
val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }.toMap()
val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces))
val remoteNonces = remoteMusig2SwapIns.map { it.serialId to it.secretNonce.publicNonce() }
val txComplete = TxComplete(fundingParams.channelId, (localNonces + remoteNonces).sortedBy { it.first }.map { it.second })
val next = copy(txCompleteSent = txComplete)
if (next.isComplete) {
Pair(next, next.validateTx(txComplete))
Expand Down Expand Up @@ -885,11 +895,16 @@ data class InteractiveTxSession(
}
sharedInputs.first()
}
val receivedNonces = (localInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>() + remoteInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>())
.sortedBy { it.serialId }
.zip(txCompleteReceived.publicNonces)
.associate { it.first.serialId to it.second }

localOnlyInputs.filterIsInstance<InteractiveTxInput.LocalMusig2SwapIn>().forEach {
txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
receivedNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
}
remoteOnlyInputs.filterIsInstance<InteractiveTxInput.RemoteSwapInMusig2>().forEach {
txCompleteReceived.publicNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
receivedNonces[it.serialId] ?: return InteractiveTxSessionAction.MissingNonce(fundingParams.channelId, it.serialId)
}
val sharedTx = SharedTransaction(sharedInput, sharedOutput, localOnlyInputs, remoteOnlyInputs, localOnlyOutputs, remoteOnlyOutputs, fundingParams.lockTime)
val tx = sharedTx.buildUnsignedTx()
Expand Down
19 changes: 6 additions & 13 deletions src/commonMain/kotlin/fr/acinq/lightning/wire/InteractiveTxTlv.kt
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,20 @@ sealed class TxRemoveInputTlv : Tlv
sealed class TxRemoveOutputTlv : Tlv

sealed class TxCompleteTlv : Tlv {
data class Nonces(val nonces: Map<Long, PublicNonce>): TxCompleteTlv() {
/** nonces for all Musig2 swap-in inputs, ordered by serial id */
data class Nonces(val nonces: List<PublicNonce>): TxCompleteTlv() {
override val tag: Long get() = Nonces.tag

override fun write(out: Output) {
LightningCodecs.writeU16(nonces.size, out)
nonces.forEach { (serialId, nonce) ->
LightningCodecs.writeBigSize(serialId, out)
LightningCodecs.writeBytes(nonce.toByteArray(), out)
}
nonces.forEach { LightningCodecs.writeBytes(it.toByteArray(), out) }
}

companion object : TlvValueReader<Nonces> {
const val tag: Long = 101
override fun read(input: Input): Nonces {
val noncesCount = LightningCodecs.u16(input)
val nonces = (1..noncesCount).map {
val serialId = LightningCodecs.bigSize(input)
val nonce = PublicNonce.fromBin(LightningCodecs.bytes(input, 66))
serialId to nonce
}
return Nonces(nonces.toMap())
val count = input.availableBytes / 66
val nonces = (0 until count).map { PublicNonce.fromBin(LightningCodecs.bytes(input, 66)) }
return Nonces(nonces)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,9 @@ data class TxComplete(
) : InteractiveTxConstructionMessage(), HasChannelId {
override val type: Long get() = TxComplete.type

val publicNonces: Map<Long, PublicNonce> = tlvs.get<TxCompleteTlv.Nonces>()?.nonces?.toMap() ?: mapOf()
val publicNonces: List<PublicNonce> = tlvs.get<TxCompleteTlv.Nonces>()?.nonces ?: listOf()

constructor(channelId: ByteVector32, publicNonces: Map<Long, PublicNonce>) : this(channelId, TlvStream(TxCompleteTlv.Nonces(publicNonces)))
constructor(channelId: ByteVector32, publicNonces: List<PublicNonce>) : this(channelId, TlvStream(TxCompleteTlv.Nonces(publicNonces)))

override fun write(out: Output) {
LightningCodecs.writeBytes(channelId.toByteArray(), out)
Expand Down

0 comments on commit 504c49d

Please sign in to comment.