Skip to content

Commit

Permalink
Do not send the previous tx for swap-in inputs
Browse files Browse the repository at this point in the history
They use taproot v1, providing the tx output and not the entire tx is safe (see #579).
Here we add the swap-in input output and txout to the swap-in TLV, so this change does not interface with proposed changes to the LN spec.
  • Loading branch information
sstone committed Jan 29, 2024
1 parent 3529964 commit c0a6d5a
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ data class WalletState(val addresses: Map<String, List<Utxo>>) {
data class Utxo(val txId: TxId, val outputIndex: Int, val blockHeight: Long, val previousTx: Transaction) {
val outPoint = OutPoint(previousTx, outputIndex.toLong())
val amount = previousTx.txOut[outputIndex].amount
val txOut = previousTx.txOut[outputIndex]
}

/**
Expand Down
46 changes: 28 additions & 18 deletions src/commonMain/kotlin/fr/acinq/lightning/channel/InteractiveTx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ sealed class InteractiveTxInput {
sealed interface Incoming

sealed class Local : InteractiveTxInput(), Outgoing {
abstract val previousTx: Transaction
abstract val previousTx: Transaction?
abstract val previousTxOutput: Long
override val txOut: TxOut
get() = previousTx.txOut[previousTxOutput.toInt()]
}

/** A local-only input that funds the interactive transaction. */
data class LocalOnly(override val serialId: Long, override val previousTx: Transaction, override val previousTxOutput: Long, override val sequence: UInt) : Local() {
override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput)
override val txOut: TxOut
get() = previousTx.txOut[previousTxOutput.toInt()]
}

/** A local input that funds the interactive transaction, coming from a 2-of-2 swap-in transaction. */
Expand All @@ -122,12 +122,18 @@ sealed class InteractiveTxInput {

data class LocalSwapIn(
override val serialId: Long,
override val previousTx: Transaction,
override val previousTxOutput: Long,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParams
) : Local() {
override val outPoint: OutPoint = OutPoint(previousTx, previousTxOutput)
override val outPoint: OutPoint
get() = swapInParams.outPoint
override val previousTx: Transaction?
get() = null
override val txOut: TxOut
get() = swapInParams.txOut

override val previousTxOutput: Long
get() = outPoint.index
}

/**
Expand All @@ -144,11 +150,15 @@ sealed class InteractiveTxInput {

data class RemoteSwapIn(
override val serialId: Long,
override val outPoint: OutPoint,
override val txOut: TxOut,
override val sequence: UInt,
val swapInParams: TxAddInputTlv.SwapInParams
) : Remote()
) : Remote() {
override val txOut: TxOut
get() = swapInParams.txOut

override val outPoint: OutPoint
get() = swapInParams.outPoint
}

/** The shared input can be added by us or by our peer, depending on who initiated the protocol. */
data class Shared(
Expand Down Expand Up @@ -287,10 +297,8 @@ data class FundingContributions(val inputs: List<InteractiveTxInput.Outgoing>, v

else -> InteractiveTxInput.LocalSwapIn(
0,
i.previousTx.stripInputWitnesses(),
i.outputIndex.toLong(),
0xfffffffdU,
TxAddInputTlv.SwapInParams(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay),
TxAddInputTlv.SwapInParams(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay, i.outPoint, i.txOut),
)
}
}
Expand Down Expand Up @@ -538,7 +546,7 @@ data class FullySignedSharedTransaction(override val tx: SharedTransaction, over
val ctx = swapInProtocol.session(unsignedTx, unsignedTx.txIn.indexOfFirst { it.outPoint == i.outPoint }, unsignedTx.txIn.map { tx.spentOutputs[it.outPoint]!! }, commonNonce)
val commonSig = ctx.add(listOf(userSig.sig, serverSig.sig))
val witness = swapInProtocol.witness(commonSig)
Pair(i.serialId, TxIn(OutPoint(i.previousTx, i.previousTxOutput), ByteVector.empty, i.sequence.toLong(), witness))
Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), witness))
}

val remoteOnlyTxIn = tx.remoteOnlyInputs().sortedBy { i -> i.serialId }.zip(remoteSigs.witnesses).map { (i, w) -> Pair(i.serialId, TxIn(i.outPoint, ByteVector.empty, i.sequence.toLong(), w)) }
Expand Down Expand Up @@ -674,7 +682,7 @@ data class InteractiveTxSession(
}

is InteractiveTxInput.LocalSwapIn -> {
val swapInParams = TxAddInputTlv.SwapInParams(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay)
val swapInParams = TxAddInputTlv.SwapInParams(swapInKeys.userPublicKey, swapInKeys.remoteServerPublicKey, swapInKeys.userRefundPublicKey, swapInKeys.refundDelay, msg.value.outPoint, msg.value.txOut)
TxAddInput(fundingParams.channelId, msg.value.serialId, msg.value.previousTx, msg.value.previousTxOutput, msg.value.sequence, TlvStream(swapInParams))
}

Expand Down Expand Up @@ -711,14 +719,16 @@ data class InteractiveTxSession(
return Either.Left(InteractiveTxSessionAction.DuplicateSerialId(message.channelId, message.serialId))
}
// We check whether this is the shared input or a remote input.
val input = when (message.previousTx) {
null -> {
val input = when {
message.previousTx == null && message.swapInParams != null -> {
InteractiveTxInput.RemoteSwapIn(message.serialId, message.sequence, message.swapInParams)
}
message.previousTx == null -> {
val expectedSharedOutpoint = fundingParams.sharedInput?.info?.outPoint ?: return Either.Left(InteractiveTxSessionAction.PreviousTxMissing(message.channelId, message.serialId))
val receivedSharedOutpoint = message.sharedInput ?: return Either.Left(InteractiveTxSessionAction.PreviousTxMissing(message.channelId, message.serialId))
if (expectedSharedOutpoint != receivedSharedOutpoint) return Either.Left(InteractiveTxSessionAction.PreviousTxMissing(message.channelId, message.serialId))
InteractiveTxInput.Shared(message.serialId, receivedSharedOutpoint, fundingParams.sharedInput.info.txOut.publicKeyScript, message.sequence, previousFunding.toLocal, previousFunding.toRemote)
}

else -> {
if (message.previousTx.txOut.size <= message.previousTxOutput) {
return Either.Left(InteractiveTxSessionAction.InputOutOfBounds(message.channelId, message.serialId, message.previousTx.txid, message.previousTxOutput))
Expand All @@ -735,7 +745,7 @@ data class InteractiveTxSession(
val txOut = message.previousTx.txOut[message.previousTxOutput.toInt()]
when {
message.swapInParams != null -> {
InteractiveTxInput.RemoteSwapIn(message.serialId, outpoint, txOut, message.sequence, message.swapInParams)
InteractiveTxInput.RemoteSwapIn(message.serialId, message.sequence, message.swapInParams)
}

message.swapInParamsLegacy != null -> InteractiveTxInput.RemoteLegacySwapIn(message.serialId, outpoint, txOut, message.sequence, message.swapInParamsLegacy.userKey, message.swapInParamsLegacy.serverKey, message.swapInParamsLegacy.refundDelay)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ object Deserialization {
)
0x03 -> InteractiveTxInput.LocalSwapIn(
serialId = readNumber(),
previousTx = readTransaction(),
previousTxOutput = readNumber(),
sequence = readNumber().toUInt(),
swapInParams = TxAddInputTlv.SwapInParams.read(this)
)
Expand All @@ -272,8 +270,6 @@ object Deserialization {
)
0x03 -> InteractiveTxInput.RemoteSwapIn(
serialId = readNumber(),
outPoint = readOutPoint(),
txOut = TxOut.read(readDelimitedByteArray()),
sequence = readNumber().toUInt(),
swapInParams = TxAddInputTlv.SwapInParams.read(this)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,6 @@ object Serialization {
is InteractiveTxInput.LocalSwapIn -> i.run {
write(0x03)
writeNumber(serialId)
writeBtcObject(previousTx)
writeNumber(previousTxOutput)
writeNumber(sequence.toLong())
swapInParams.write(this@writeLocalInteractiveTxInput)
}
Expand All @@ -313,8 +311,6 @@ object Serialization {
is InteractiveTxInput.RemoteSwapIn -> i.run {
write(0x03)
writeNumber(serialId)
writeBtcObject(outPoint)
writeBtcObject(txOut)
writeNumber(sequence.toLong())
swapInParams.write(this@writeRemoteInteractiveTxInput)
}
Expand Down
12 changes: 10 additions & 2 deletions src/commonMain/kotlin/fr/acinq/lightning/wire/InteractiveTxTlv.kt
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ sealed class TxAddInputTlv : Tlv {
}

/** When adding a swap-in input to an interactive-tx, the user needs to provide the corresponding script parameters. */
data class SwapInParams(val userKey: PublicKey, val serverKey: PublicKey, val userRefundKey: PublicKey, val refundDelay: Int) : TxAddInputTlv() {
data class SwapInParams(val userKey: PublicKey, val serverKey: PublicKey, val userRefundKey: PublicKey, val refundDelay: Int, val outPoint: OutPoint, val txOut: TxOut) : TxAddInputTlv() {
override val tag: Long get() = SwapInParams.tag
override fun write(out: Output) {
LightningCodecs.writeBytes(userKey.value, out)
LightningCodecs.writeBytes(serverKey.value, out)
LightningCodecs.writeBytes(userRefundKey.value, out)
LightningCodecs.writeU32(refundDelay, out)
val blob1 = OutPoint.write(outPoint)
LightningCodecs.writeU16(blob1.size, out)
LightningCodecs.writeBytes(blob1, out)
val blob2 = TxOut.write(txOut)
LightningCodecs.writeU16(blob2.size, out)
LightningCodecs.writeBytes(blob2, out)
}

companion object : TlvValueReader<SwapInParams> {
Expand All @@ -59,7 +65,9 @@ sealed class TxAddInputTlv : Tlv {
PublicKey(LightningCodecs.bytes(input, 33)),
PublicKey(LightningCodecs.bytes(input, 33)),
PublicKey(LightningCodecs.bytes(input, 33)),
LightningCodecs.u32(input)
LightningCodecs.u32(input),
OutPoint.read(LightningCodecs.bytes(input, LightningCodecs.u16(input))),
TxOut.read(LightningCodecs.bytes(input, LightningCodecs.u16(input)))
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,18 @@ class SwapInManagerTestsCommon : LightningTestSuite() {
@Test
fun `swap funds -- ignore inputs from pending channel`() {
val (waitForFundingSigned, _) = WaitForFundingSignedTestsCommon.init()
val inputs = waitForFundingSigned.state.signingSession.fundingTx.tx.localInputs
val wallet = run {
val utxos = waitForFundingSigned.state.signingSession.fundingTx.tx.localInputs.map { i -> WalletState.Utxo(i.outPoint.txid, i.outPoint.index.toInt(), 100, i.previousTx) }
val parentTxs = inputs.associate { it.outPoint.txid to Transaction(version = 2, txIn = listOf(), txOut = listOf(it.txOut), lockTime = 0) }
val utxos = inputs.map { i -> WalletState.Utxo(i.outPoint.txid, i.outPoint.index.toInt(), 100, parentTxs[i.outPoint.txid]!!) }
WalletState(mapOf(dummyAddress to utxos))
}
val mgr = SwapInManager(listOf(waitForFundingSigned.state), logger)
val cmd = SwapInCommand.TrySwapIn(currentBlockHeight = 150, wallet = wallet, swapInParams = SwapInParams(minConfirmations = 5, maxConfirmations = 720, refundDelay = 900), trustedTxs = emptySet())
mgr.process(cmd).also { assertNull(it) }

// The pending channel is aborted: we can reuse those inputs.
mgr.process(SwapInCommand.UnlockWalletInputs(wallet.utxos.map { it.outPoint }.toSet()))
mgr.process(SwapInCommand.UnlockWalletInputs(inputs.map { it.outPoint }.toSet()))
mgr.process(cmd).also { assertNotNull(it) }
}

Expand All @@ -156,15 +158,16 @@ class SwapInManagerTestsCommon : LightningTestSuite() {
val inputs = alice1.commitments.active.map { it.localFundingStatus }.filterIsInstance<LocalFundingStatus.UnconfirmedFundingTx>().flatMap { it.sharedTx.tx.localInputs }
assertEquals(3, inputs.size) // 1 initial funding input and 2 splice inputs
val wallet = run {
val utxos = inputs.map { i -> WalletState.Utxo(i.outPoint.txid, i.outPoint.index.toInt(), 100, i.previousTx) }
val parentTxs = inputs.associate { it.outPoint.txid to Transaction(version = 2, txIn = listOf(), txOut = listOf(it.txOut), lockTime = 0) }
val utxos = inputs.map { i -> WalletState.Utxo(i.outPoint.txid, i.outPoint.index.toInt(), 100, parentTxs[i.outPoint.txid]!!) }
WalletState(mapOf(dummyAddress to utxos))
}
val mgr = SwapInManager(listOf(alice1.state), logger)
val cmd = SwapInCommand.TrySwapIn(currentBlockHeight = 150, wallet = wallet, swapInParams = SwapInParams(minConfirmations = 5, maxConfirmations = 720, refundDelay = 900), trustedTxs = emptySet())
mgr.process(cmd).also { assertNull(it) }

// The channel is aborted: we can reuse those inputs.
mgr.process(SwapInCommand.UnlockWalletInputs(wallet.utxos.map { it.outPoint }.toSet()))
mgr.process(SwapInCommand.UnlockWalletInputs(inputs.map { it.outPoint }.toSet()))
mgr.process(cmd).also { result ->
assertNotNull(result)
assertEquals(3, result.walletInputs.size)
Expand All @@ -186,7 +189,8 @@ class SwapInManagerTestsCommon : LightningTestSuite() {
assertEquals(1, alice3.commitments.all.size)
assertIs<LocalFundingStatus.ConfirmedFundingTx>(alice3.commitments.latest.localFundingStatus)
val wallet = run {
val utxos = inputs.map { i -> WalletState.Utxo(i.outPoint.txid, i.outPoint.index.toInt(), 100, i.previousTx) }
val parentTxs = inputs.associate { it.outPoint.txid to Transaction(version = 2, txIn = listOf(), txOut = listOf(it.txOut), lockTime = 0) }
val utxos = inputs.map { i -> WalletState.Utxo(i.outPoint.txid, i.outPoint.index.toInt(), 100, parentTxs[i.outPoint.txid]!!) }
WalletState(mapOf(dummyAddress to utxos))
}
val mgr = SwapInManager(listOf(alice3.state), logger)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class InteractiveTxTestsCommon : LightningTestSuite() {
assertEquals(signedTx.lockTime, 42)
assertEquals(signedTx.txIn.size, 4)
assertEquals(signedTx.txOut.size, 3)
Transaction.correctlySpends(signedTx, (sharedTxA.sharedTx.localInputs + sharedTxB.sharedTx.localInputs).map { it.previousTx }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
Transaction.correctlySpends(signedTx, (sharedTxA.sharedTx.localInputs + sharedTxB.sharedTx.localInputs).associate { it.outPoint to it.txOut }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
val feerate = Transactions.fee2rate(signedTxA.tx.fees, signedTx.weight())
assertTrue(targetFeerate <= feerate && feerate <= targetFeerate * 1.25, "unexpected feerate (target=$targetFeerate actual=$feerate)")
}
Expand Down Expand Up @@ -162,7 +162,7 @@ class InteractiveTxTestsCommon : LightningTestSuite() {
assertEquals(signedTx.lockTime, 0)
assertEquals(signedTx.txIn.size, 2)
assertEquals(signedTx.txOut.size, 3)
Transaction.correctlySpends(signedTx, (sharedTxA.sharedTx.localInputs + sharedTxB.sharedTx.localInputs).map { it.previousTx }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
Transaction.correctlySpends(signedTx, (sharedTxA.sharedTx.localInputs + sharedTxB.sharedTx.localInputs).associate { it.outPoint to it.txOut }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
val feerate = Transactions.fee2rate(signedTxB.tx.fees, signedTx.weight())
assertTrue(targetFeerate <= feerate && feerate <= targetFeerate * 1.25, "unexpected feerate (target=$targetFeerate actual=$feerate)")
}
Expand Down Expand Up @@ -214,7 +214,7 @@ class InteractiveTxTestsCommon : LightningTestSuite() {
// The resulting transaction is valid and has the right feerate.
val signedTxB = sharedTxB.sharedTx.sign(bob3, f.keyManagerB, f.fundingParamsB, f.localParamsB, f.localParamsA.nodeId).addRemoteSigs(f.channelKeysB, f.fundingParamsB, signedTxA.localSigs)
assertNotNull(signedTxB)
Transaction.correctlySpends(signedTxB.signedTx, (sharedTxA.sharedTx.localInputs + sharedTxB.sharedTx.localInputs).map { it.previousTx }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
Transaction.correctlySpends(signedTxB.signedTx, (sharedTxA.sharedTx.localInputs + sharedTxB.sharedTx.localInputs).associate { it.outPoint to it.txOut }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
val feerate = Transactions.fee2rate(signedTxB.tx.fees, signedTxB.signedTx.weight())
assertTrue(targetFeerate <= feerate && feerate <= targetFeerate * 1.25, "unexpected feerate (target=$targetFeerate actual=$feerate)")
}
Expand Down Expand Up @@ -279,7 +279,7 @@ class InteractiveTxTestsCommon : LightningTestSuite() {
assertEquals(signedTx.lockTime, 0)
assertEquals(signedTx.txIn.size, 2)
assertEquals(signedTx.txOut.size, 2)
Transaction.correctlySpends(signedTx, sharedTxA.sharedTx.localInputs.map { it.previousTx }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
Transaction.correctlySpends(signedTx, sharedTxA.sharedTx.localInputs.associate { it.outPoint to it.txOut }, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
val feerate = Transactions.fee2rate(signedTxA.tx.fees, signedTx.weight())
assertTrue(targetFeerate <= feerate && feerate <= targetFeerate * 1.25, "unexpected feerate (target=$targetFeerate actual=$feerate)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,13 @@ class WaitForFundingConfirmedTestsCommon : LightningTestSuite() {
assertIs<RbfStatus.InProgress>(bob1.state.rbfStatus)
assertEquals(actions1.size, 1)
actions1.hasOutgoingMessage<TxAckRbf>()
val txAddInput = alice.state.latestFundingTx.sharedTx.tx.localInputs.first().run { TxAddInput(alice.channelId, serialId, previousTx, previousTxOutput, sequence) }
val input = alice.state.latestFundingTx.sharedTx.tx.localInputs.first()
val tlvs = when (input) {
is InteractiveTxInput.LocalSwapIn -> TlvStream<TxAddInputTlv>(input.swapInParams)
is InteractiveTxInput.LocalLegacySwapIn -> TlvStream<TxAddInputTlv>(TxAddInputTlv.SwapInParamsLegacy(input.userKey, input.serverKey, input.refundDelay))
is InteractiveTxInput.LocalOnly -> TlvStream.empty()
}
val txAddInput = input.run { TxAddInput(alice.channelId, serialId, previousTx, previousTxOutput, sequence, tlvs) }
val (bob2, actions2) = bob1.process(ChannelCommand.MessageReceived(txAddInput))
assertEquals(actions2.size, 1)
actions2.hasOutgoingMessage<TxAddInput>()
Expand Down

0 comments on commit c0a6d5a

Please sign in to comment.