Skip to content

Commit

Permalink
Add payer note (#670)
Browse files Browse the repository at this point in the history
Allows adding an optional payer note when paying an offer and save the payer note on the recipient's side.
  • Loading branch information
thomash-acinq authored Jun 18, 2024
1 parent c792e0f commit 4897222
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 51 deletions.
4 changes: 2 additions & 2 deletions src/commonMain/kotlin/fr/acinq/lightning/NodeParams.kt
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ data class NodeParams(
val liquidityPolicy: MutableStateFlow<LiquidityPolicy>,
val minFinalCltvExpiryDelta: CltvExpiryDelta,
val maxFinalCltvExpiryDelta: CltvExpiryDelta,
val bolt12invoiceExpiry: Duration
val bolt12invoiceExpiry: Duration,
) {
val nodePrivateKey get() = keyManager.nodeKeys.nodeKey.privateKey
val nodeId get() = keyManager.nodeKeys.nodeKey.publicKey
Expand Down Expand Up @@ -232,7 +232,7 @@ data class NodeParams(
liquidityPolicy = MutableStateFlow<LiquidityPolicy>(LiquidityPolicy.Auto(maxAbsoluteFee = 2_000.sat, maxRelativeFeeBasisPoints = 3_000 /* 3000 = 30 % */, skipAbsoluteFeeCheck = false)),
minFinalCltvExpiryDelta = Bolt11Invoice.DEFAULT_MIN_FINAL_EXPIRY_DELTA,
maxFinalCltvExpiryDelta = CltvExpiryDelta(360),
bolt12invoiceExpiry = 60.seconds
bolt12invoiceExpiry = 60.seconds,
)

/**
Expand Down
6 changes: 3 additions & 3 deletions src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ data class PayInvoice(override val paymentId: UUID, override val amount: MilliSa
val paymentHash: ByteVector32 = paymentDetails.paymentHash
val recipient: PublicKey = paymentDetails.paymentRequest.nodeId
}
data class PayOffer(override val paymentId: UUID, val payerKey: PrivateKey, override val amount: MilliSatoshi, val offer: OfferTypes.Offer, val fetchInvoiceTimeout: Duration, val trampolineFeesOverride: List<TrampolineFees>? = null) : SendPayment()
data class PayOffer(override val paymentId: UUID, val payerKey: PrivateKey, val payerNote: String?, override val amount: MilliSatoshi, val offer: OfferTypes.Offer, val fetchInvoiceTimeout: Duration, val trampolineFeesOverride: List<TrampolineFees>? = null) : SendPayment()
// @formatter:on

data class PurgeExpiredPayments(val fromCreatedAt: Long, val toCreatedAt: Long) : PaymentCommand()
Expand Down Expand Up @@ -662,7 +662,7 @@ class Peer(
return res.await()
}

suspend fun payOffer(amount: MilliSatoshi, offer: OfferTypes.Offer, payerKey: PrivateKey, fetchInvoiceTimeout: Duration): SendPaymentResult {
suspend fun payOffer(amount: MilliSatoshi, offer: OfferTypes.Offer, payerKey: PrivateKey, payerNote: String?, fetchInvoiceTimeout: Duration): SendPaymentResult {
val res = CompletableDeferred<SendPaymentResult>()
val paymentId = UUID.randomUUID()
this.launch {
Expand All @@ -672,7 +672,7 @@ class Peer(
.first()
)
}
send(PayOffer(paymentId, payerKey, amount, offer, fetchInvoiceTimeout))
send(PayOffer(paymentId, payerKey, payerNote, amount, offer, fetchInvoiceTimeout))
return res.await()
}

Expand Down
11 changes: 9 additions & 2 deletions src/commonMain/kotlin/fr/acinq/lightning/payment/OfferManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class OfferManager(val nodeParams: NodeParams, val walletParams: WalletParams, v
* @return invoice requests that must be sent and the corresponding path_id that must be used in case of a timeout.
*/
fun requestInvoice(payOffer: PayOffer): Triple<ByteVector32, List<OnionMessage>, OfferTypes.InvoiceRequest> {
val request = OfferTypes.InvoiceRequest(payOffer.offer, payOffer.amount, 1, nodeParams.features.bolt12Features(), payOffer.payerKey, nodeParams.chainHash)
val request = OfferTypes.InvoiceRequest(payOffer.offer, payOffer.amount, 1, nodeParams.features.bolt12Features(), payOffer.payerKey, payOffer.payerNote, nodeParams.chainHash)
val replyPathId = randomBytes32()
pendingInvoiceRequests[replyPathId] = PendingInvoiceRequest(payOffer, request)
// We add dummy hops to the reply path: this way the receiver only learns that we're at most 3 hops away from our peer.
Expand Down Expand Up @@ -150,7 +150,14 @@ class OfferManager(val nodeParams: NodeParams, val walletParams: WalletParams, v
else -> {
val amount = request.amount ?: (request.offer.amount!! * request.quantity)
val preimage = randomBytes32()
val pathId = OfferPaymentMetadata.V1(ByteVector32(decrypted.pathId), amount, preimage, request.payerId, request.quantity, currentTimestampMillis()).toPathId(nodeParams.nodePrivateKey)
val truncatedPayerNote = request.payerNote?.let {
if (it.length <= 64) {
it
} else {
it.take(63) + ""
}
}
val pathId = OfferPaymentMetadata.V1(ByteVector32(decrypted.pathId), amount, preimage, request.payerId, truncatedPayerNote, request.quantity, currentTimestampMillis()).toPathId(nodeParams.nodePrivateKey)
val recipientPayload = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(pathId))).write().toByteVector()
val paymentInfo = OfferTypes.PaymentInfo(
feeBase = remoteChannelUpdates.maxOfOrNull { it.feeBaseMsat } ?: walletParams.invoiceDefaultRoutingFees.feeBase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ sealed class OfferPaymentMetadata {
override val amount: MilliSatoshi,
override val preimage: ByteVector32,
val payerKey: PublicKey,
val payerNote: String?,
val quantity: Long,
override val createdAtMillis: Long
) : OfferPaymentMetadata() {
Expand All @@ -68,17 +69,20 @@ sealed class OfferPaymentMetadata {
LightningCodecs.writeBytes(payerKey.value, out)
LightningCodecs.writeU64(quantity, out)
LightningCodecs.writeU64(createdAtMillis, out)
payerNote?.let { LightningCodecs.writeBytes(it.encodeToByteArray(), out) }
}

companion object {
fun read(input: Input): V1 = V1(
offerId = LightningCodecs.bytes(input, 32).byteVector32(),
amount = LightningCodecs.u64(input).msat,
preimage = LightningCodecs.bytes(input, 32).byteVector32(),
payerKey = PublicKey(LightningCodecs.bytes(input, 33)),
quantity = LightningCodecs.u64(input),
createdAtMillis = LightningCodecs.u64(input),
)
fun read(input: Input): V1 {
val offerId = LightningCodecs.bytes(input, 32).byteVector32()
val amount = LightningCodecs.u64(input).msat
val preimage = LightningCodecs.bytes(input, 32).byteVector32()
val payerKey = PublicKey(LightningCodecs.bytes(input, 33))
val quantity = LightningCodecs.u64(input)
val createdAtMillis = LightningCodecs.u64(input)
val payerNote = if (input.availableBytes > 0) LightningCodecs.bytes(input, input.availableBytes).decodeToString() else null
return V1(offerId, amount, preimage, payerKey, payerNote, quantity, createdAtMillis)
}
}
}

Expand All @@ -104,11 +108,12 @@ sealed class OfferPaymentMetadata {
val input = ByteArrayInput(pathId.toByteArray())
when (LightningCodecs.byte(input)) {
1 -> {
if (input.availableBytes != 185) return null
val metadata = LightningCodecs.bytes(input, 121)
if (input.availableBytes < 185) return null
val metadataSize = input.availableBytes - 64
val metadata = LightningCodecs.bytes(input, metadataSize)
val signature = LightningCodecs.bytes(input, 64).byteVector64()
// Note that the signature includes the version byte.
if (!Crypto.verifySignature(Crypto.sha256(pathId.take(122)), signature, nodeId)) return null
if (!Crypto.verifySignature(Crypto.sha256(pathId.take(1 + metadataSize)), signature, nodeId)) return null
// This call is safe since we verified that we have the right number of bytes and the signature was valid.
return V1.read(ByteArrayInput(metadata))
}
Expand Down
2 changes: 2 additions & 0 deletions src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,7 @@ object OfferTypes {
quantity: Long,
features: Features,
payerKey: PrivateKey,
payerNote: String?,
chain: BlockHash,
additionalTlvs: Set<InvoiceRequestTlv> = setOf(),
customTlvs: Set<GenericTlv> = setOf()
Expand All @@ -906,6 +907,7 @@ object OfferTypes {
if (offer.quantityMax != null) InvoiceRequestQuantity(quantity) else null,
if (features != Features.empty) InvoiceRequestFeatures(features) else null,
InvoiceRequestPayerId(payerKey.publicKey()),
payerNote?.let { InvoiceRequestPayerNote(it) },
) + additionalTlvs
val signature = signSchnorr(
signatureTag,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val payerKey = randomKey()
val chain = BlockHash(randomBytes32())
val offer = Offer.createNonBlindedOffer(10000.msat, "test offer", nodeKey.publicKey(), Features.empty, chain)
val request = InvoiceRequest(offer, 11000.msat, 1, Features.empty, payerKey, chain)
val request = InvoiceRequest(offer, 11000.msat, 1, Features.empty, payerKey, null, chain)
val invoice = Bolt12Invoice(
request,
randomBytes32(),
Expand Down Expand Up @@ -122,7 +122,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val payerKey = randomKey()
val chain = BlockHash(randomBytes32())
val offer = Offer.createNonBlindedOffer(10000.msat, "test offer", nodeKey.publicKey(), Features.empty, chain)
val basicRequest = InvoiceRequest(offer, 11000.msat, 1, Features.empty, payerKey, chain)
val basicRequest = InvoiceRequest(offer, 11000.msat, 1, Features.empty, payerKey, null, chain)
val requestWithUnknownTlv = basicRequest.copy(records = TlvStream(basicRequest.records.records, setOf(GenericTlv(87, ByteVector.fromHex("0404")))))
val invoice = Bolt12Invoice(
requestWithUnknownTlv,
Expand All @@ -143,7 +143,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val payerKey = randomKey()
val chain = BlockHash(randomBytes32())
val offer = Offer.createNonBlindedOffer(10000.msat, "test offer", nodeKey.publicKey(), Features.empty, chain)
val request = InvoiceRequest(offer, 11000.msat, 1, Features.empty, payerKey, chain)
val request = InvoiceRequest(offer, 11000.msat, 1, Features.empty, payerKey, null, chain)
val invoice = Bolt12Invoice(
request,
randomBytes32(),
Expand Down Expand Up @@ -189,7 +189,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val payerKey = randomKey()
val chain = BlockHash(randomBytes32())
val offer = Offer.createNonBlindedOffer(15000.msat, "test offer", nodeKey.publicKey(), Features.empty, chain)
val request = InvoiceRequest(offer, 15000.msat, 1, Features.empty, payerKey, chain)
val request = InvoiceRequest(offer, 15000.msat, 1, Features.empty, payerKey, null, chain)
assertTrue(request.quantity_opt == null) // when paying for a single item, the quantity field must not be present
val invoice = Bolt12Invoice(
request,
Expand Down Expand Up @@ -271,7 +271,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val payerKey = randomKey()
val chain = BlockHash(randomBytes32())
val offer = Offer.createNonBlindedOffer(5000.msat, "test offer", nodeKey.publicKey(), Features.empty, chain)
val request = InvoiceRequest(offer, 5000.msat, 1, Features.empty, payerKey, chain)
val request = InvoiceRequest(offer, 5000.msat, 1, Features.empty, payerKey, null, chain)
val invoice = Bolt12Invoice(
request,
randomBytes32(),
Expand Down Expand Up @@ -410,7 +410,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val encodedOffer = "lno1pg9k66twd9kkzmpqw35hq93pqf8l2vtlq5w87m4vqfnvtn82adk9wadfgratnp2wg7l7ha4u0gzqw"
assertEquals(offer.toString(), encodedOffer)
assertEquals(Offer.decode(encodedOffer).get(), offer)
val request = InvoiceRequest(offer, 12000000.msat, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash)
val request = InvoiceRequest(offer, 12000000.msat, 1, Features.empty, payerKey, null, Block.LivenetGenesisBlock.hash)
// Invoice request generation is not reproducible because we add randomness in the first TLV.
val encodedRequest = "lnr1qqs289chx8swkpmwf3uzexfxr0kk9syavsjcmkuur5qgjqt60ayjdec2pdkkjmnfd4skcgr5d9cpvggzfl6nzlc9r3lkatqzvmzue6htd3tht22ql2uc2nj8hl4ld0r6qsr4qgr0u2xq4dh3kdevrf4zg6hx8a60jv0gxe0ptgyfc6xkryqqqqqqqpfq8dcmqpvzzqc773pe7cufzn08jgsys0w6xt0m0fp3u7v6tnj6weplh4ctyyvwfmcypemfjk6kryqxycnnmu2vp9tuw00eslf0grp6rf3hk6v76aynyn4lclra0fyyk2gxyf9hx73rnm775204tn8cltacw4s0fzd5c0lxm58s"
val decodedRequest = InvoiceRequest.decode(encodedRequest).get()
Expand Down Expand Up @@ -448,7 +448,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val encodedOffer = "lno1pqzpktszqq9q6mtfde5k6ctvyphkven9wgtzzq7y3tyhuz0newawkdds924x6pet2aexssdrf5je2g2het9xpgw275"
assertEquals(offer.toString(), encodedOffer)
assertEquals(Offer.decode(encodedOffer).get(), offer)
val request = InvoiceRequest(offer, 456001234.msat, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash)
val request = InvoiceRequest(offer, 456001234.msat, 1, Features.empty, payerKey, null, Block.LivenetGenesisBlock.hash)
// Invoice request generation is not reproducible because we add randomness in the first TLV.
val encodedRequest = "lnr1qqsf4h8fsnpjkj057gjg9c3eqhv889440xh0z6f5kng9vsaad8pgq7sgqsdjuqsqpgxk66twd9kkzmpqdanxvetjzcss83y2e9lqnu7tht4ntvp24fksw26hwf5yrg6dyk2jz472efs2rjh42qsxlc5vp2m0rvmjcxn2y34wv0m5lyc7sdj7zksgn35dvxgqqqqqqqzjqsdjupkjtqssx05572ha26x39rczan5yft22pgwa72jw8gytavkm5ydn7yf5kpgh7pq2hlvh7twke5830a44wc0zlrs2kph4ghndm60ahwcznhcd0pcpl332qv5xuemksazy3zx5s63kqmqkphrn9jg4ln55pc6syrwqukejeq"
val decodedRequest = InvoiceRequest.decode(encodedRequest).get()
Expand Down Expand Up @@ -495,7 +495,7 @@ class Bolt12InvoiceTestsCommon : LightningTestSuite() {
val encodedOffer = "lno1qgsyxjtl6luzd9t3pr62xr7eemp6awnejusgf6gw45q75vcfqqqqqqqgqvqcdgq2zdhkven9wgs8w6t5dqs8zatpde6xjarezggkzmrfvdj5qcnfvaeksmms9e3k7mg5qgp7s93pqvn6l4vemgezdarq3wt2kpp0u4vt74vzz8futen7ej97n93jypp57"
assertEquals(offer.toString(), encodedOffer)
assertEquals(Offer.decode(encodedOffer).get(), offer)
val request = InvoiceRequest(offer, 7200000.msat, 72, Features.empty, payerKey, Block.TestnetGenesisBlock.hash)
val request = InvoiceRequest(offer, 7200000.msat, 72, Features.empty, payerKey, null, Block.TestnetGenesisBlock.hash)
// Invoice request generation is not reproducible because we add randomness in the first TLV.
val encodedRequest = "lnr1qqs8lqvnh3kg9uj003lxlxyj8hthymgq4p9ms0ag0ryx5uw8gsuus4gzypp5jl7hlqnf2ugg7j3slkwwcwht57vhyzzwjr4dq84rxzgqqqqqqzqrqxr2qzsndanxvetjypmkjargypch2ctww35hg7gjz9skc6trv4qxy6t8wd5x7upwvdhk69qzq05pvggry7hatxw6xgn0gcytj64sgtl9tzl4tqs360z7vlkv305evv3qgd84qgzrf9la07pxj4cs3a9rplvuasawhfuewgyyay826q02xvysqqqqqpfqxmwaqptqzjzcyyp8cmgrl28nvm3wlqqheha0t570rgaszg7mzvvzvwmx9s92nmyujk0sgpef8dt57nygu3dnfhglymt6mnle6j8s28rler8wv3zygen07v4ddfplc9qs7nkdzwcelm2rs552slkpv45xxng65ne6y4dlq2764gqv"
val decodedRequest = InvoiceRequest.decode(encodedRequest).get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() {
fun `reject blinded payment with amount too low`() = runSuspendTest {
val paymentHandler = IncomingPaymentHandler(TestConstants.Bob.nodeParams, InMemoryPaymentsDb())
val cltvExpiry = TestConstants.Bob.nodeParams.minFinalCltvExpiryDelta.toCltvExpiry(TestConstants.defaultBlockHeight.toLong())
val metadata = OfferPaymentMetadata.V1(randomBytes32(), 100_000_000.msat, randomBytes32(), randomKey().publicKey(), 1, currentTimestampMillis())
val metadata = OfferPaymentMetadata.V1(randomBytes32(), 100_000_000.msat, randomBytes32(), randomKey().publicKey(), null, 1, currentTimestampMillis())
val pathId = metadata.toPathId(TestConstants.Bob.nodeParams.nodePrivateKey)
val amountTooLow = metadata.amount - 10_000_000.msat
val (finalPayload, route) = makeBlindedPayload(TestConstants.Bob.nodeParams.nodeId, amountTooLow, amountTooLow, cltvExpiry, pathId)
Expand All @@ -1277,7 +1277,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() {
fun `reject blinded payment with payment_hash mismatch`() = runSuspendTest {
val paymentHandler = IncomingPaymentHandler(TestConstants.Bob.nodeParams, InMemoryPaymentsDb())
val cltvExpiry = TestConstants.Bob.nodeParams.minFinalCltvExpiryDelta.toCltvExpiry(TestConstants.defaultBlockHeight.toLong())
val metadata = OfferPaymentMetadata.V1(randomBytes32(), 100_000_000.msat, randomBytes32(), randomKey().publicKey(), 1, currentTimestampMillis())
val metadata = OfferPaymentMetadata.V1(randomBytes32(), 100_000_000.msat, randomBytes32(), randomKey().publicKey(), null, 1, currentTimestampMillis())
val pathId = metadata.toPathId(TestConstants.Bob.nodeParams.nodePrivateKey)
val (finalPayload, route) = makeBlindedPayload(TestConstants.Bob.nodeParams.nodeId, metadata.amount, metadata.amount, cltvExpiry, pathId)
val add = makeUpdateAddHtlc(8, randomBytes32(), paymentHandler, metadata.paymentHash.reversed(), finalPayload, route.blindingKey)
Expand Down Expand Up @@ -1357,7 +1357,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() {
preimage: ByteVector32 = randomBytes32(),
payerKey: PublicKey = randomKey().publicKey()
): Pair<PaymentOnion.FinalPayload.Blinded, RouteBlinding.BlindedRoute> {
val pathId = OfferPaymentMetadata.V1(offerId, totalAmount, preimage, payerKey, quantity, currentTimestampMillis()).toPathId(TestConstants.Bob.nodeParams.nodePrivateKey)
val pathId = OfferPaymentMetadata.V1(offerId, totalAmount, preimage, payerKey, null, quantity, currentTimestampMillis()).toPathId(TestConstants.Bob.nodeParams.nodePrivateKey)
val recipientData = RouteBlindingEncryptedData(TlvStream(RouteBlindingEncryptedDataTlv.PathId(pathId)))
val route = RouteBlinding.create(randomKey(), listOf(recipientNodeId), listOf(recipientData.write().toByteVector())).route
val payload = PaymentOnion.FinalPayload.Blinded(
Expand Down
Loading

0 comments on commit 4897222

Please sign in to comment.