diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/MessageOnion.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/MessageOnion.kt index ad778794d..0e9516995 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/MessageOnion.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/MessageOnion.kt @@ -6,7 +6,6 @@ import fr.acinq.bitcoin.io.ByteArrayInput import fr.acinq.bitcoin.io.ByteArrayOutput import fr.acinq.bitcoin.io.Input import fr.acinq.bitcoin.io.Output -import fr.acinq.lightning.EncodedNodeId import fr.acinq.lightning.crypto.RouteBlinding sealed class OnionMessagePayloadTlv : Tlv { @@ -58,6 +57,62 @@ sealed class OnionMessagePayloadTlv : Tlv { EncryptedData(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) } } + + /** + * In order to pay a Bolt 12 offer, we must send an onion message to request an invoice corresponding to that offer. + * The creator of the offer will send us an invoice back through our blinded reply path. + */ + data class InvoiceRequest(val tlvs: TlvStream) : OnionMessagePayloadTlv() { + override val tag: Long get() = InvoiceRequest.tag + override fun write(out: Output) = OfferTypes.InvoiceRequest.tlvSerializer.write(tlvs, out) + + companion object : TlvValueReader { + const val tag: Long = 64 + + override fun read(input: Input): InvoiceRequest = + InvoiceRequest(OfferTypes.InvoiceRequest.tlvSerializer.read(input)) + } + } + + /** + * When receiving an invoice request, we must send an onion message back containing an invoice corresponding to the + * requested offer (if it was an offer we published). + */ + data class Invoice(val tlvs: TlvStream) : OnionMessagePayloadTlv() { + override val tag: Long get() = Invoice.tag + override fun write(out: Output) = OfferTypes.Invoice.tlvSerializer.write(tlvs, out) + + companion object : TlvValueReader { + const val tag: Long = 66 + + override fun read(input: Input): Invoice = + Invoice(OfferTypes.Invoice.tlvSerializer.read(input)) + } + } + + /** + * This message may be used when we receive an invalid invoice or invoice request. + * It contains information helping senders figure out why their message was invalid. + */ + data class InvoiceError(val tlvs: TlvStream) : OnionMessagePayloadTlv() { + override val tag: Long get() = InvoiceError.tag + override fun write(out: Output) = tlvSerializer.write(tlvs, out) + + companion object : TlvValueReader { + const val tag: Long = 68 + + val tlvSerializer = TlvStreamSerializer( + true, @Suppress("UNCHECKED_CAST") mapOf( + OfferTypes.ErroneousField.tag to OfferTypes.ErroneousField.Companion as TlvValueReader, + OfferTypes.SuggestedValue.tag to OfferTypes.SuggestedValue.Companion as TlvValueReader, + OfferTypes.Error.tag to OfferTypes.Error.Companion as TlvValueReader, + ) + ) + + override fun read(input: Input): InvoiceError = + InvoiceError(tlvSerializer.read(input)) + } + } } data class MessageOnion(val records: TlvStream) { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt new file mode 100644 index 000000000..85b9189b6 --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/OfferTypes.kt @@ -0,0 +1,1038 @@ +package fr.acinq.lightning.wire + +import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.io.ByteArrayOutput +import fr.acinq.bitcoin.io.Input +import fr.acinq.bitcoin.io.Output +import fr.acinq.bitcoin.utils.Either +import fr.acinq.bitcoin.utils.Either.Left +import fr.acinq.bitcoin.utils.Either.Right +import fr.acinq.bitcoin.utils.Try +import fr.acinq.bitcoin.utils.runTrying +import fr.acinq.lightning.CltvExpiryDelta +import fr.acinq.lightning.Features +import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.MilliSatoshi +import fr.acinq.lightning.crypto.RouteBlinding + +/** + * Lightning Bolt 12 offers + * see https://github.com/lightning/bolts/blob/master/12-offer-encoding.md + */ +object OfferTypes { + /** Data provided to reach the issuer of an offer or invoice. */ + sealed class ContactInfo { + /** If the offer or invoice issuer doesn't want to hide their identity, they can directly share their public nodeId. */ + data class RecipientNodeId(val nodeId: PublicKey) : ContactInfo() + + /** If the offer or invoice issuer wants to hide their identity, they instead provide blinded paths. */ + data class BlindedPath(val route: RouteBlinding.BlindedRoute) : ContactInfo() + } + + fun writePath(path: ContactInfo.BlindedPath, out: Output) { + LightningCodecs.writeEncodedNodeId(path.route.introductionNodeId, out) + LightningCodecs.writeBytes(path.route.blindingKey.value, out) + LightningCodecs.writeByte(path.route.blindedNodes.size, out) + for (node in path.route.blindedNodes) { + LightningCodecs.writeBytes(node.blindedPublicKey.value, out) + LightningCodecs.writeU16(node.encryptedPayload.size(), out) + LightningCodecs.writeBytes(node.encryptedPayload, out) + } + } + + fun readPath(input: Input): ContactInfo.BlindedPath { + val introductionNodeId = LightningCodecs.encodedNodeId(input) + val blindingKey = PublicKey(LightningCodecs.bytes(input, 33)) + val blindedNodes = ArrayList() + val numBlindedNodes = LightningCodecs.byte(input) + for (i in 1 .. numBlindedNodes) { + val blindedKey = PublicKey(LightningCodecs.bytes(input, 33)) + val payload = ByteVector(LightningCodecs.bytes(input, LightningCodecs.u16(input))) + blindedNodes.add(RouteBlinding.BlindedNode(blindedKey, payload)) + } + return ContactInfo.BlindedPath(RouteBlinding.BlindedRoute(introductionNodeId, blindingKey, blindedNodes)) + } + + sealed class Bolt12Tlv : Tlv + + sealed class InvoiceTlv : Bolt12Tlv() + + sealed class InvoiceRequestTlv : InvoiceTlv() + + sealed class OfferTlv : InvoiceRequestTlv() + + sealed class InvoiceErrorTlv : Bolt12Tlv() + + /** + * Chains for which the offer is valid. If empty, bitcoin mainnet is implied. + */ + data class OfferChains(val chains: List) : OfferTlv() { + override val tag: Long get() = OfferChains.tag + + override fun write(out: Output) { + for (chain in chains) { + LightningCodecs.writeBytes(chain.value, out) + } + } + + companion object : TlvValueReader { + const val tag: Long = 2 + override fun read(input: Input): OfferChains { + val chains = ArrayList() + while (input.availableBytes > 0) { + chains.add(BlockHash(LightningCodecs.bytes(input, 32))) + } + return OfferChains(chains) + } + } + } + + /** + * Data from the offer creator to themselves, for instance a signature that authenticates the offer so that they don't need to store the offer. + */ + data class OfferMetadata(val data: ByteVector) : OfferTlv() { + override val tag: Long get() = OfferMetadata.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(data, out) + } + + companion object : TlvValueReader { + const val tag: Long = 4 + override fun read(input: Input): OfferMetadata { + return OfferMetadata(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Three-letter code of the currency the offer is denominated in. If empty, bitcoin is implied. + */ + data class OfferCurrency(val iso4217: String) : OfferTlv() { + override val tag: Long get() = OfferCurrency.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(iso4217.encodeToByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 6 + override fun read(input: Input): OfferCurrency { + return OfferCurrency(LightningCodecs.bytes(input, input.availableBytes).decodeToString()) + } + } + } + + /** + * Amount to pay per item. As we only support bitcoin, the amount is in msat. + */ + data class OfferAmount(val amount: MilliSatoshi) : OfferTlv() { + override val tag: Long get() = OfferAmount.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(amount.toLong(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 8 + override fun read(input: Input): OfferAmount { + return OfferAmount(MilliSatoshi(LightningCodecs.tu64(input))) + } + } + } + + /** + * Description of the purpose of the payment. + */ + data class OfferDescription(val description: String) : OfferTlv() { + override val tag: Long get() = OfferDescription.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(description.encodeToByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 10 + override fun read(input: Input): OfferDescription { + return OfferDescription(LightningCodecs.bytes(input, input.availableBytes).decodeToString()) + } + } + } + + /** + * Features supported to pay the offer. + */ + data class OfferFeatures(val features: Features) : OfferTlv() { + override val tag: Long get() = OfferFeatures.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(features.toByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 12 + override fun read(input: Input): OfferFeatures { + return OfferFeatures(Features(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Time after which the offer is no longer valid. + */ + data class OfferAbsoluteExpiry(val absoluteExpirySeconds: Long) : OfferTlv() { + override val tag: Long get() = OfferAbsoluteExpiry.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(absoluteExpirySeconds, out) + } + + companion object : TlvValueReader { + const val tag: Long = 14 + override fun read(input: Input): OfferAbsoluteExpiry { + return OfferAbsoluteExpiry(LightningCodecs.tu64(input)) + } + } + } + + /** + * Paths that can be used to retrieve an invoice. + */ + data class OfferPaths(val paths: List) : OfferTlv() { + override val tag: Long get() = OfferPaths.tag + + override fun write(out: Output) { + for(path in paths){ + writePath(path, out) + } + } + + companion object : TlvValueReader { + const val tag: Long = 16 + override fun read(input: Input): OfferPaths { + val paths = ArrayList() + while (input.availableBytes > 0) { + val path = readPath(input) + paths.add(path) + } + return OfferPaths(paths) + } + } + } + + /** + * Name of the offer creator. + */ + data class OfferIssuer(val issuer: String) : OfferTlv() { + override val tag: Long get() = OfferIssuer.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(issuer.encodeToByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 18 + override fun read(input: Input): OfferIssuer { + return OfferIssuer(LightningCodecs.bytes(input, input.availableBytes).decodeToString()) + } + } + } + + /** + * If present, the item described in the offer can be purchased multiple times with a single payment. + * If max = 0, there is no limit on the quantity that can be purchased in a single payment. + * If max > 1, it corresponds to the maximum number of items that be purchased in a single payment. + */ + data class OfferQuantityMax(val max: Long) : OfferTlv() { + override val tag: Long get() = OfferQuantityMax.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(max, out) + } + + companion object : TlvValueReader { + const val tag: Long = 20 + override fun read(input: Input): OfferQuantityMax { + return OfferQuantityMax(LightningCodecs.tu64(input)) + } + } + } + + /** + * Public key of the offer creator. + * If `OfferPaths` is present, they must be used to retrieve an invoice even if this public key corresponds to a node id in the public network. + * If `OfferPaths` is not present, this public key must correspond to a node id in the public network that needs to be contacted to retrieve an invoice. + */ + data class OfferNodeId(val publicKey: PublicKey) : OfferTlv() { + override val tag: Long get() = OfferNodeId.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(publicKey.value, out) + } + + companion object : TlvValueReader { + const val tag: Long = 22 + override fun read(input: Input): OfferNodeId { + return OfferNodeId(PublicKey(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Random data to provide enough entropy so that some fields of the invoice request / invoice can be revealed without revealing the others. + */ + data class InvoiceRequestMetadata(val data: ByteVector) : InvoiceRequestTlv() { + override val tag: Long get() = InvoiceRequestMetadata.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(data, out) + } + + companion object : TlvValueReader { + const val tag: Long = 0 + override fun read(input: Input): InvoiceRequestMetadata { + return InvoiceRequestMetadata(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * If `OfferChains` is present, this specifies which chain is going to be used to pay. + */ + data class InvoiceRequestChain(val hash: BlockHash) : InvoiceRequestTlv() { + override val tag: Long get() = InvoiceRequestChain.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(hash.value, out) + } + + companion object : TlvValueReader { + const val tag: Long = 80 + override fun read(input: Input): InvoiceRequestChain { + return InvoiceRequestChain(BlockHash(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Amount that the sender is going to send. + */ + data class InvoiceRequestAmount(val amount: MilliSatoshi) : InvoiceRequestTlv() { + override val tag: Long get() = InvoiceRequestAmount.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(amount.toLong(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 82 + override fun read(input: Input): InvoiceRequestAmount { + return InvoiceRequestAmount(MilliSatoshi(LightningCodecs.tu64(input))) + } + } + } + + /** + * Features supported by the sender to pay the offer. + */ + data class InvoiceRequestFeatures(val features: Features) : InvoiceRequestTlv() { + override val tag: Long get() = InvoiceRequestFeatures.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(features.toByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 84 + override fun read(input: Input): InvoiceRequestFeatures { + return InvoiceRequestFeatures(Features(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Number of items to purchase. Only use if the offer supports purchasing multiple items at once. + */ + data class InvoiceRequestQuantity(val quantity: Long) : InvoiceRequestTlv() { + override val tag: Long get() = InvoiceRequestQuantity.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(quantity, out) + } + + companion object : TlvValueReader { + const val tag: Long = 86 + override fun read(input: Input): InvoiceRequestQuantity { + return InvoiceRequestQuantity(LightningCodecs.tu64(input)) + } + } + } + + /** + * A public key for which the sender knows the corresponding private key. + * This can be used to prove that you are the sender. + */ + data class InvoiceRequestPayerId(val publicKey: PublicKey) : InvoiceRequestTlv() { + override val tag: Long get() = InvoiceRequestPayerId.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(publicKey.value, out) + } + + companion object : TlvValueReader { + const val tag: Long = 88 + override fun read(input: Input): InvoiceRequestPayerId { + return InvoiceRequestPayerId(PublicKey(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * A message from the sender. + */ + data class InvoiceRequestPayerNote(val note: String) : InvoiceRequestTlv() { + override val tag: Long get() = InvoiceRequestPayerNote.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(note.encodeToByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 89 + override fun read(input: Input): InvoiceRequestPayerNote { + return InvoiceRequestPayerNote(LightningCodecs.bytes(input, input.availableBytes).decodeToString()) + } + } + } + + /** + * Payment paths to send the payment to. + */ + data class InvoicePaths(val paths: List) : InvoiceTlv() { + override val tag: Long get() = InvoicePaths.tag + + override fun write(out: Output) { + for(path in paths){ + writePath(path, out) + } + } + + companion object : TlvValueReader { + const val tag: Long = 160 + override fun read(input: Input): InvoicePaths { + val paths = ArrayList() + while (input.availableBytes > 0) { + val path = readPath(input) + paths.add(path) + } + return InvoicePaths(paths) + } + } + } + + data class PaymentInfo(val feeBase: MilliSatoshi, + val feeProportionalMillionths: Int, + val cltvExpiryDelta: CltvExpiryDelta, + val minHtlc: MilliSatoshi, + val maxHtlc: MilliSatoshi, + val allowedFeatures: Features) { + fun fee(amount: MilliSatoshi): MilliSatoshi { + return feeBase + amount * feeProportionalMillionths / 1_000_000L + } + } + + /** + * Costs and parameters of the paths in `InvoicePaths`. + */ + data class InvoiceBlindedPay(val paymentInfos: List) : InvoiceTlv() { + override val tag: Long get() = InvoiceBlindedPay.tag + + override fun write(out: Output) { + for (paymentInfo in paymentInfos) { + LightningCodecs.writeU32(paymentInfo.feeBase.msat.toInt(), out) + LightningCodecs.writeU32(paymentInfo.feeProportionalMillionths, out) + LightningCodecs.writeU16(paymentInfo.cltvExpiryDelta.toInt(), out) + LightningCodecs.writeU64(paymentInfo.minHtlc.msat, out) + LightningCodecs.writeU64(paymentInfo.maxHtlc.msat, out) + val featuresArray = paymentInfo.allowedFeatures.toByteArray() + LightningCodecs.writeU16(featuresArray.size, out) + LightningCodecs.writeBytes(featuresArray, out) + } + } + + companion object : TlvValueReader { + const val tag: Long = 162 + override fun read(input: Input): InvoiceBlindedPay { + val paymentInfos = ArrayList() + while (input.availableBytes > 0) { + val feeBase = MilliSatoshi(LightningCodecs.u32(input).toLong()) + val feeProportionalMillionths = LightningCodecs.u32(input) + val cltvExpiryDelta = CltvExpiryDelta(LightningCodecs.u16(input)) + val minHtlc = MilliSatoshi(LightningCodecs.u64(input)) + val maxHtlc = MilliSatoshi(LightningCodecs.u64(input)) + val allowedFeatures = Features(LightningCodecs.bytes(input, LightningCodecs.u16(input))) + paymentInfos.add(PaymentInfo(feeBase, feeProportionalMillionths, cltvExpiryDelta, minHtlc, maxHtlc, allowedFeatures)) + } + return InvoiceBlindedPay(paymentInfos) + } + } + } + + /** + * Time at which the invoice was created. + */ + data class InvoiceCreatedAt(val timestampSeconds: Long) : InvoiceTlv() { + override val tag: Long get() = InvoiceCreatedAt.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(timestampSeconds, out) + } + + companion object : TlvValueReader { + const val tag: Long = 164 + override fun read(input: Input): InvoiceCreatedAt { + return InvoiceCreatedAt(LightningCodecs.tu64(input)) + } + } + } + + /** + * Duration after which the invoice can no longer be paid. + */ + data class InvoiceRelativeExpiry(val seconds: Long) : InvoiceTlv() { + override val tag: Long get() = InvoiceRelativeExpiry.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(seconds, out) + } + + companion object : TlvValueReader { + const val tag: Long = 166 + override fun read(input: Input): InvoiceRelativeExpiry { + return InvoiceRelativeExpiry(LightningCodecs.tu64(input)) + } + } + } + + /** + * Hash whose preimage will be released in exchange for the payment. + */ + data class InvoicePaymentHash(val hash: ByteVector32) : InvoiceTlv() { + override val tag: Long get() = InvoicePaymentHash.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(hash, out) + } + + companion object : TlvValueReader { + const val tag: Long = 168 + override fun read(input: Input): InvoicePaymentHash { + return InvoicePaymentHash(ByteVector32(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Amount to pay. Must be the same as `InvoiceRequestAmount` if it was present. + */ + data class InvoiceAmount(val amount: MilliSatoshi) : InvoiceTlv() { + override val tag: Long get() = InvoiceAmount.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(amount.toLong(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 170 + override fun read(input: Input): InvoiceAmount { + return InvoiceAmount(MilliSatoshi(LightningCodecs.tu64(input))) + } + } + } + + data class FallbackAddress(val version: Int, val value: ByteVector) + + /** + * Onchain addresses to use to pay the invoice in case the lightning payment fails. + */ + data class InvoiceFallbacks(val addresses: List) : InvoiceTlv() { + override val tag: Long get() = InvoiceFallbacks.tag + + override fun write(out: Output) { + for (address in addresses) { + LightningCodecs.writeByte(address.version, out) + LightningCodecs.writeU16(address.value.size(), out) + LightningCodecs.writeBytes(address.value, out) + } + } + + companion object : TlvValueReader { + const val tag: Long = 172 + override fun read(input: Input): InvoiceFallbacks { + val addresses = ArrayList() + while (input.availableBytes > 0) { + val version = LightningCodecs.byte(input) + val value = ByteVector(LightningCodecs.bytes(input, LightningCodecs.u16(input))) + addresses.add(FallbackAddress(version, value)) + } + return InvoiceFallbacks(addresses) + } + } + } + + /** + * Features supported to pay the invoice. + */ + data class InvoiceFeatures(val features: Features) : InvoiceTlv() { + override val tag: Long get() = InvoiceFeatures.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(features.toByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 174 + override fun read(input: Input): InvoiceFeatures { + return InvoiceFeatures(Features(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Public key of the invoice recipient. + */ + data class InvoiceNodeId(val nodeId: PublicKey) : InvoiceTlv() { + override val tag: Long get() = InvoiceNodeId.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(nodeId.value, out) + } + + companion object : TlvValueReader { + const val tag: Long = 176 + override fun read(input: Input): InvoiceNodeId { + return InvoiceNodeId(PublicKey(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + /** + * Signature from the sender when used in an invoice request. + * Signature from the recipient when used in an invoice. + */ + data class Signature(val signature: ByteVector64) : InvoiceRequestTlv() { + override val tag: Long get() = Signature.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(signature, out) + } + + companion object : TlvValueReader { + const val tag: Long = 240 + override fun read(input: Input): Signature { + return Signature(ByteVector64(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + fun filterOfferFields(tlvs: TlvStream): TlvStream { + // Offer TLVs are in the range (0, 80). + return TlvStream( + tlvs.records.filterIsInstance().toSet(), + tlvs.unknown.filter{it.tag < 80}.toSet() + ) + } + + fun filterInvoiceRequestFields(tlvs: TlvStream): TlvStream { + // Invoice request TLVs are in the range [0, 160): invoice request metadata (tag 0), offer TLVs, and additional invoice request TLVs in the range [80, 160). + return TlvStream( + tlvs.records.filterIsInstance().toSet(), + tlvs.unknown.filter{it.tag < 160}.toSet() + ) + } + + data class ErroneousField(val fieldTag: Long) : InvoiceErrorTlv() { + override val tag: Long get() = ErroneousField.tag + + override fun write(out: Output) { + LightningCodecs.writeTU64(fieldTag, out) + } + + companion object : TlvValueReader { + const val tag: Long = 1 + override fun read(input: Input): ErroneousField { + return ErroneousField(LightningCodecs.tu64(input)) + } + } + } + + data class SuggestedValue(val value: ByteVector) : InvoiceErrorTlv() { + override val tag: Long get() = SuggestedValue.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(value, out) + } + + companion object : TlvValueReader { + const val tag: Long = 3 + override fun read(input: Input): SuggestedValue { + return SuggestedValue(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) + } + } + } + + data class Error(val message: String) : InvoiceErrorTlv() { + override val tag: Long get() = Error.tag + + override fun write(out: Output) { + LightningCodecs.writeBytes(message.encodeToByteArray(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 5 + override fun read(input: Input): Error { + return Error(LightningCodecs.bytes(input, input.availableBytes).decodeToString()) + } + } + } + + sealed class InvalidTlvPayload { + abstract val tag: Long + } + data class MissingRequiredTlv(override val tag: Long) : InvalidTlvPayload() + data class ForbiddenTlv(override val tag: Long) : InvalidTlvPayload() + + data class Offer(val records: TlvStream) { + val chains: List = records.get()?.chains ?: listOf(Block.LivenetGenesisBlock.hash) + val metadata: ByteVector? = records.get()?.data + val currency: String? = records.get()?.iso4217 + val amount: MilliSatoshi? = if (currency == null) { + records.get()?.amount + } else { + null // TODO: add exchange rates + } + val description: String = records.get()!!.description + val features: Features = records.get()?.features ?: Features.empty + val expirySeconds: Long? = records.get()?.absoluteExpirySeconds + private val paths: List? = records.get()?.paths + val issuer: String? = records.get()?.issuer + val quantityMax: Long? = records.get()?.max?.let { if (it == 0L) Long.MAX_VALUE else it } + val nodeId: PublicKey = records.get()!!.publicKey + + val contactInfos: List = paths ?: listOf(ContactInfo.RecipientNodeId(nodeId)) + + fun encode(): String { + val data = tlvSerializer.write(records) + return Bech32.encodeBytes(hrp, data, Bech32.Encoding.Beck32WithoutChecksum) + } + + override fun toString(): String = encode() + + val offerId: ByteVector32 = rootHash(records) + + companion object { + val hrp = "lno" + + /** + * @param amount_opt amount if it can be determined at offer creation time. + * @param description description of the offer. + * @param nodeId the nodeId to use for this offer, which should be different from our public nodeId if we're hiding behind a blinded route. + * @param features invoice features. + * @param chain chain on which the offer is valid. + */ + operator fun invoke( + amount_opt: MilliSatoshi?, + description: String, + nodeId: PublicKey, + features: Features, + chain: BlockHash, + additionalTlvs: Set = setOf(), + customTlvs: Set = setOf() + ): Offer { + val tlvs: Set = setOfNotNull( + if (chain != Block.LivenetGenesisBlock.hash) OfferChains(listOf(chain)) else null, + amount_opt?.let { OfferAmount(it) }, + OfferDescription(description), + if (features != Features.empty) OfferFeatures(features) else null, + OfferNodeId(nodeId), + ) + additionalTlvs + return Offer(TlvStream(tlvs, customTlvs)) + } + + fun validate(records: TlvStream): Either { + if (records.get() == null) return Left(MissingRequiredTlv(10L)) + if (records.get() == null) return Left(MissingRequiredTlv(22L)) + if (records.unknown.any { it.tag >= 80 }) + return Left(ForbiddenTlv(records.unknown.find{it.tag >= 80}!!.tag)) + return Right(Offer(records)) + } + + val tlvSerializer = TlvStreamSerializer( + false, @Suppress("UNCHECKED_CAST") mapOf( + OfferChains.tag to OfferChains as TlvValueReader, + OfferMetadata.tag to OfferMetadata as TlvValueReader, + OfferCurrency.tag to OfferCurrency as TlvValueReader, + OfferAmount.tag to OfferAmount as TlvValueReader, + OfferDescription.tag to OfferDescription as TlvValueReader, + OfferFeatures.tag to OfferFeatures as TlvValueReader, + OfferAbsoluteExpiry.tag to OfferAbsoluteExpiry as TlvValueReader, + OfferPaths.tag to OfferPaths as TlvValueReader, + OfferIssuer.tag to OfferIssuer as TlvValueReader, + OfferQuantityMax.tag to OfferQuantityMax as TlvValueReader, + OfferNodeId.tag to OfferNodeId as TlvValueReader, + ) + ) + + fun decode(s: String): Try = runTrying { + val (prefix, encoded, encoding) = Bech32.decodeBytes(s.lowercase(), true) + require(prefix == hrp) + require(encoding == Bech32.Encoding.Beck32WithoutChecksum) + val tlvs = tlvSerializer.read(encoded) + when (val offer = validate(tlvs)) { + is Left -> throw IllegalArgumentException(offer.value.toString()) + is Right -> offer.value + } + } + } + } + + data class InvoiceRequest(val records: TlvStream) { + val offer: Offer = Offer.validate(filterOfferFields(records)).right!! + + val metadata: ByteVector = records.get()!!.data + val chain: BlockHash = records.get()?.hash ?: Block.LivenetGenesisBlock.hash + val amount: MilliSatoshi? = records.get()?.amount + val features: Features = records.get()?.features ?: Features.empty + val quantity_opt: Long? = records.get()?.quantity + val quantity: Long = quantity_opt ?: 1 + val payerId: PublicKey = records.get()!!.publicKey + val payerNote: String? = records.get()?.note + private val signature: ByteVector64 = records.get()!!.signature + + fun isValid(): Boolean = + (offer.amount == null || amount == null || offer.amount * quantity <= amount) && + (offer.amount != null || amount != null) && + offer.chains.contains(chain) && + ((offer.quantityMax == null && quantity_opt == null) || (offer.quantityMax != null && quantity_opt != null && quantity <= offer.quantityMax)) && + Features.areCompatible(offer.features, features) && + checkSignature() + + fun checkSignature(): Boolean = + verifySchnorr( + signatureTag, + rootHash(removeSignature(records)), + signature, + payerId + ) + + fun encode(): String { + val data = tlvSerializer.write(records) + return Bech32.encodeBytes(hrp, data, Bech32.Encoding.Beck32WithoutChecksum) + } + + override fun toString(): String = encode() + + fun unsigned(): TlvStream = removeSignature(records) + + companion object { + val hrp = "lnr" + val signatureTag: ByteVector = + ByteVector(("lightning" + "invoice_request" + "signature").encodeToByteArray()) + + /** + * Create a request to fetch an invoice for a given offer. + * + * @param offer Bolt 12 offer. + * @param amount amount that we want to pay. + * @param quantity quantity of items we're buying. + * @param features invoice features. + * @param payerKey private key identifying the payer: this lets us prove we're the ones who paid the invoice. + * @param chain chain we want to use to pay this offer. + */ + operator fun invoke( + offer: Offer, + amount: MilliSatoshi, + quantity: Long, + features: Features, + payerKey: PrivateKey, + chain: BlockHash, + additionalTlvs: Set = setOf(), + customTlvs: Set = setOf() + ): InvoiceRequest { + require(offer.chains.contains(chain)) + require(quantity == 1L || offer.quantityMax != null) + val tlvs: Set = offer.records.records + setOfNotNull( + InvoiceRequestMetadata(randomBytes32()), + InvoiceRequestChain(chain), + InvoiceRequestAmount(amount), + if (offer.quantityMax != null) InvoiceRequestQuantity(quantity) else null, + if (features != Features.empty) InvoiceRequestFeatures(features) else null, + InvoiceRequestPayerId(payerKey.publicKey()), + ) + additionalTlvs + val signature = signSchnorr( + signatureTag, + rootHash(TlvStream(tlvs, offer.records.unknown + customTlvs)), + payerKey + ) + return InvoiceRequest(TlvStream(tlvs + Signature(signature), offer.records.unknown + customTlvs)) + } + + fun validate(records: TlvStream): Either { + when (val offer = Offer.validate(filterOfferFields(records))) { + is Left -> return Left(offer.value) + is Right -> {} + } + if (records.get() == null) return Left(MissingRequiredTlv(0L)) + if (records.get() == null) return Left(MissingRequiredTlv(88)) + if (records.get() == null) return Left(MissingRequiredTlv(240)) + if (records.unknown.any { it.tag >= 160 }) + return Left(ForbiddenTlv(records.unknown.find{ it.tag >= 160 }!!.tag)) + return Right(InvoiceRequest(records)) + } + + val tlvSerializer = TlvStreamSerializer( + false, @Suppress("UNCHECKED_CAST") mapOf( + InvoiceRequestMetadata.tag to InvoiceRequestMetadata as TlvValueReader, + // Offer part that must be copy-pasted from above + OfferChains.tag to OfferChains as TlvValueReader, + OfferMetadata.tag to OfferMetadata as TlvValueReader, + OfferCurrency.tag to OfferCurrency as TlvValueReader, + OfferAmount.tag to OfferAmount as TlvValueReader, + OfferDescription.tag to OfferDescription as TlvValueReader, + OfferFeatures.tag to OfferFeatures as TlvValueReader, + OfferAbsoluteExpiry.tag to OfferAbsoluteExpiry as TlvValueReader, + OfferPaths.tag to OfferPaths as TlvValueReader, + OfferIssuer.tag to OfferIssuer as TlvValueReader, + OfferQuantityMax.tag to OfferQuantityMax as TlvValueReader, + OfferNodeId.tag to OfferNodeId as TlvValueReader, + // Invoice request part + InvoiceRequestChain.tag to InvoiceRequestChain as TlvValueReader, + InvoiceRequestAmount.tag to InvoiceRequestAmount as TlvValueReader, + InvoiceRequestFeatures.tag to InvoiceRequestFeatures as TlvValueReader, + InvoiceRequestQuantity.tag to InvoiceRequestQuantity as TlvValueReader, + InvoiceRequestPayerId.tag to InvoiceRequestPayerId as TlvValueReader, + InvoiceRequestPayerNote.tag to InvoiceRequestPayerNote as TlvValueReader, + Signature.tag to Signature as TlvValueReader, + ) + ) + + fun decode(s: String): Try = runTrying { + val (prefix, encoded, encoding) = Bech32.decodeBytes(s.lowercase(), true) + require(prefix == hrp) + require(encoding == Bech32.Encoding.Beck32WithoutChecksum) + val tlvs = tlvSerializer.read(encoded) + when (val invoiceRequest = validate(tlvs)) { + is Left -> throw IllegalArgumentException(invoiceRequest.value.toString()) + is Right -> invoiceRequest.value + } + } + } + } + + object Invoice { + val tlvSerializer = TlvStreamSerializer( + false, @Suppress("UNCHECKED_CAST") mapOf( + // Invoice request part that must be copy-pasted from above + InvoiceRequestMetadata.tag to InvoiceRequestMetadata as TlvValueReader, + OfferChains.tag to OfferChains as TlvValueReader, + OfferMetadata.tag to OfferMetadata as TlvValueReader, + OfferCurrency.tag to OfferCurrency as TlvValueReader, + OfferAmount.tag to OfferAmount as TlvValueReader, + OfferDescription.tag to OfferDescription as TlvValueReader, + OfferFeatures.tag to OfferFeatures as TlvValueReader, + OfferAbsoluteExpiry.tag to OfferAbsoluteExpiry as TlvValueReader, + OfferPaths.tag to OfferPaths as TlvValueReader, + OfferIssuer.tag to OfferIssuer as TlvValueReader, + OfferQuantityMax.tag to OfferQuantityMax as TlvValueReader, + OfferNodeId.tag to OfferNodeId as TlvValueReader, + InvoiceRequestChain.tag to InvoiceRequestChain as TlvValueReader, + InvoiceRequestAmount.tag to InvoiceRequestAmount as TlvValueReader, + InvoiceRequestFeatures.tag to InvoiceRequestFeatures as TlvValueReader, + InvoiceRequestQuantity.tag to InvoiceRequestQuantity as TlvValueReader, + InvoiceRequestPayerId.tag to InvoiceRequestPayerId as TlvValueReader, + InvoiceRequestPayerNote.tag to InvoiceRequestPayerNote as TlvValueReader, + // Invoice part + InvoicePaths.tag to InvoicePaths as TlvValueReader, + InvoiceBlindedPay.tag to InvoiceBlindedPay as TlvValueReader, + InvoiceCreatedAt.tag to InvoiceCreatedAt as TlvValueReader, + InvoiceRelativeExpiry.tag to InvoiceRelativeExpiry as TlvValueReader, + InvoicePaymentHash.tag to InvoicePaymentHash as TlvValueReader, + InvoiceAmount.tag to InvoiceAmount as TlvValueReader, + InvoiceFallbacks.tag to InvoiceFallbacks as TlvValueReader, + InvoiceFeatures.tag to InvoiceFeatures as TlvValueReader, + InvoiceNodeId.tag to InvoiceNodeId as TlvValueReader, + Signature.tag to Signature as TlvValueReader, + ) + ) + } + + data class InvoiceError(val records: TlvStream) { + val error = records.get()!!.message + + companion object { + fun validate(records: TlvStream): Either { + if (records.get() == null) return Left(MissingRequiredTlv(5)) + return Right(InvoiceError(records)) + } + } + } + + fun rootHash(tlvStream: TlvStream): ByteVector32 { + val encodedTlvs = (tlvStream.records + tlvStream.unknown).sortedBy { it.tag }.map { tlv -> + val out = ByteArrayOutput() + LightningCodecs.writeBigSize(tlv.tag, out) + val tag = out.toByteArray() + val data = tlv.write() + LightningCodecs.writeBigSize(data.size.toLong(), out) + LightningCodecs.writeBytes(data, out) + Pair(tag, out.toByteArray()) + } + val nonceKey = "LnNonce".encodeToByteArray() + encodedTlvs[0].second + + fun previousPowerOfTwo(n: Int): Int { + var p = 1 + while (p < n) { + p = p shl 1 + } + return p shr 1 + } + + fun merkleTree(i: Int, j: Int): ByteArray { + val (a, b) = if (j - i == 1) { + val (tag, fullTlv) = encodedTlvs[i] + Pair(hash("LnLeaf".encodeToByteArray(), fullTlv), hash(nonceKey, tag)) + } else { + val k = i + previousPowerOfTwo(j - i) + Pair(merkleTree(i, k), merkleTree(k, j)) + } + return if (LexicographicalOrdering.isLessThan(a, b)) { + hash("LnBranch".encodeToByteArray(), a + b) + } else { + hash("LnBranch".encodeToByteArray(), b + a) + } + } + + return ByteVector32(merkleTree(0, encodedTlvs.size)) + } + + private fun hash(tag: ByteArray, msg: ByteArray): ByteArray { + val tagHash = Crypto.sha256(tag) + return Crypto.sha256(tagHash + tagHash + msg) + } + + fun signSchnorr(tag: ByteVector, msg: ByteVector32, key: PrivateKey): ByteVector64 { + val h = ByteVector32(hash(tag.toByteArray(), msg.toByteArray())) + // NB: we don't add auxiliary random data to keep signatures deterministic. + return Crypto.signSchnorr(h, key, Crypto.SchnorrTweak.NoTweak) + } + + fun verifySchnorr(tag: ByteVector, msg: ByteVector32, signature: ByteVector64, publicKey: PublicKey): Boolean { + val h = ByteVector32(hash(tag.toByteArray(), msg.toByteArray())) + return Crypto.verifySignatureSchnorr(h, signature, XonlyPublicKey(publicKey)) + } + + /** We often need to remove the signature field to compute the merkle root. */ + fun removeSignature(records: TlvStream): TlvStream = + TlvStream(records.records.filter { it !is Signature }.toSet(), records.unknown) + +} diff --git a/src/commonTest/kotlin/fr/acinq/lightning/wire/OfferTypesTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/wire/OfferTypesTestsCommon.kt new file mode 100644 index 000000000..386ecdd7a --- /dev/null +++ b/src/commonTest/kotlin/fr/acinq/lightning/wire/OfferTypesTestsCommon.kt @@ -0,0 +1,316 @@ +package fr.acinq.lightning.wire + +import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.io.ByteArrayInput +import fr.acinq.bitcoin.io.ByteArrayOutput +import fr.acinq.bitcoin.io.Input +import fr.acinq.bitcoin.io.Output +import fr.acinq.lightning.* +import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.Lightning.randomKey +import fr.acinq.lightning.crypto.RouteBlinding +import fr.acinq.lightning.tests.utils.LightningTestSuite +import fr.acinq.lightning.utils.msat +import fr.acinq.lightning.utils.toByteVector +import fr.acinq.lightning.wire.OfferTypes.ContactInfo.BlindedPath +import fr.acinq.lightning.wire.OfferTypes.InvoiceRequest +import fr.acinq.lightning.wire.OfferTypes.InvoiceRequestAmount +import fr.acinq.lightning.wire.OfferTypes.InvoiceRequestChain +import fr.acinq.lightning.wire.OfferTypes.InvoiceRequestMetadata +import fr.acinq.lightning.wire.OfferTypes.InvoiceRequestPayerId +import fr.acinq.lightning.wire.OfferTypes.InvoiceRequestQuantity +import fr.acinq.lightning.wire.OfferTypes.InvoiceRequestTlv +import fr.acinq.lightning.wire.OfferTypes.Offer +import fr.acinq.lightning.wire.OfferTypes.OfferAmount +import fr.acinq.lightning.wire.OfferTypes.OfferChains +import fr.acinq.lightning.wire.OfferTypes.OfferDescription +import fr.acinq.lightning.wire.OfferTypes.OfferIssuer +import fr.acinq.lightning.wire.OfferTypes.OfferNodeId +import fr.acinq.lightning.wire.OfferTypes.OfferQuantityMax +import fr.acinq.lightning.wire.OfferTypes.Signature +import fr.acinq.lightning.wire.OfferTypes.readPath +import fr.acinq.lightning.wire.OfferTypes.removeSignature +import fr.acinq.lightning.wire.OfferTypes.rootHash +import fr.acinq.lightning.wire.OfferTypes.signSchnorr +import fr.acinq.lightning.wire.OfferTypes.writePath +import kotlin.test.* + +class OfferTypesTestsCommon : LightningTestSuite() { + val nodeKey = PrivateKey.fromHex("85d08273493e489b9330c85a3e54123874c8cd67c1bf531f4b926c9c555f8e1d") + val nodeId = nodeKey.publicKey() + + @Test + fun `invoice request is signed`() { + val sellerKey = randomKey() + val offer = Offer(100_000.msat, "test offer", sellerKey.publicKey(), Features.empty, Block.LivenetGenesisBlock.hash) + val payerKey = randomKey() + val request = InvoiceRequest(offer, 100_000.msat, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + assertTrue(request.checkSignature()) + } + + @Test + fun `minimal offer`() { + val tlvs = setOf( + OfferDescription("basic offer"), + OfferNodeId(nodeId)) + val offer = Offer(TlvStream(tlvs)) + val encoded = "lno1pg9kyctnd93jqmmxvejhy93pqvxl9c6mjgkeaxa6a0vtxqteql688v0ywa8qqwx4j05cyskn8ncrj" + assertEquals(offer, Offer.decode(encoded).get()) + assertNull(offer.amount) + assertEquals("basic offer", offer.description) + assertEquals(nodeId, offer.nodeId) + // Removing any TLV from the minimal offer makes it invalid. + for (tlv in tlvs) { + val incomplete = TlvStream(tlvs.filterNot{it == tlv}.toSet()) + assertTrue(Offer.validate(incomplete).isLeft) + val incompleteEncoded = Bech32.encodeBytes(Offer.hrp, Offer.tlvSerializer.write(incomplete), Bech32.Encoding.Beck32WithoutChecksum) + assertTrue(Offer.decode(incompleteEncoded).isFailure) + } + } + + @Test + fun `offer with amount and quantity`() { + val offer = Offer(TlvStream( + OfferChains(listOf(Block.TestnetGenesisBlock.hash)), + OfferAmount(50.msat), + OfferDescription("offer with quantity"), + OfferIssuer("alice@bigshop.com"), + OfferQuantityMax(0), + OfferNodeId(nodeId))) + val encoded = "lno1qgsyxjtl6luzd9t3pr62xr7eemp6awnejusgf6gw45q75vcfqqqqqqqgqyeq5ym0venx2u3qwa5hg6pqw96kzmn5d968jys3v9kxjcm9gp3xjemndphhqtnrdak3gqqkyypsmuhrtwfzm85mht4a3vcp0yrlgua3u3m5uqpc6kf7nqjz6v70qwg" + assertEquals(offer, Offer.decode(encoded).get()) + assertEquals(50.msat, offer.amount) + assertEquals("offer with quantity", offer.description) + assertEquals( nodeId, offer.nodeId) + assertEquals("alice@bigshop.com", offer.issuer) + assertEquals(Long.MAX_VALUE, offer.quantityMax) + } + + fun signInvoiceRequest(request: InvoiceRequest, key: PrivateKey): InvoiceRequest { + val tlvs = removeSignature(request.records) + val signature = signSchnorr(InvoiceRequest.signatureTag, rootHash(tlvs), key) + val signedRequest = InvoiceRequest(tlvs.copy(records = tlvs.records + Signature(signature))) + assertTrue(signedRequest.checkSignature()) + return signedRequest + } + + @Test + fun `check that invoice request matches offer`() { + val offer = Offer(2500.msat, "basic offer", randomKey().publicKey(), Features.empty, Block.LivenetGenesisBlock.hash) + val payerKey = randomKey() + val request = InvoiceRequest(offer, 2500.msat, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + assertTrue(request.isValid()) + assertEquals(offer, request.offer) + val biggerAmount = signInvoiceRequest(request.copy(records = TlvStream(request.records.records.map { when(it) { is InvoiceRequestAmount -> InvoiceRequestAmount(3000.msat) else -> it }}.toSet())), payerKey) + assertTrue(biggerAmount.isValid()) + assertEquals(offer, biggerAmount.offer) + val lowerAmount = signInvoiceRequest(request.copy(records = TlvStream(request.records.records.map { when(it) { is InvoiceRequestAmount -> InvoiceRequestAmount(2000.msat) else -> it }}.toSet())), payerKey) + assertFalse(lowerAmount.isValid()) + val withQuantity = signInvoiceRequest(request.copy(records = TlvStream(request.records.records + InvoiceRequestQuantity(1))), payerKey) + assertFalse(withQuantity.isValid()) + } + + @Test + fun `check that invoice request matches offer - with features`() { + val offer = Offer(2500.msat, "offer with features", randomKey().publicKey(), Features.empty, Block.LivenetGenesisBlock.hash) + val payerKey = randomKey() + val request = InvoiceRequest(offer, 2500.msat, 1, Features(Feature.BasicMultiPartPayment to FeatureSupport.Optional), payerKey, Block.LivenetGenesisBlock.hash) + assertTrue(request.isValid()) + assertEquals(offer, request.offer) + val withoutFeatures = InvoiceRequest(offer, 2500.msat, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + assertTrue(withoutFeatures.isValid()) + assertEquals(offer, withoutFeatures.offer) + val otherFeatures = InvoiceRequest(offer, 2500.msat, 1, Features(Feature.BasicMultiPartPayment to FeatureSupport.Mandatory), payerKey, Block.LivenetGenesisBlock.hash) + assertFalse(otherFeatures.isValid()) + assertEquals(offer, otherFeatures.offer) + } + + @Test + fun `check that invoice request matches offer - without amount`() { + val offer = Offer(null, "offer without amount", randomKey().publicKey(), Features.empty, Block.LivenetGenesisBlock.hash) + val payerKey = randomKey() + val request = InvoiceRequest(offer, 500.msat, 1, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + assertTrue(request.isValid()) + assertEquals(offer, request.offer) + val withoutAmount = signInvoiceRequest(request.copy(records = TlvStream(request.records.records.filterNot { it is InvoiceRequestAmount }.toSet())), payerKey) + assertFalse(withoutAmount.isValid()) + } + + @Test + fun `check that invoice request matches offer - without chain`() { + val offer = Offer(TlvStream(OfferAmount(100.msat), OfferDescription("offer without chains"), OfferNodeId(randomKey().publicKey()))) + val payerKey = randomKey() + val tlvs: Set = offer.records.records + setOf( + InvoiceRequestMetadata(ByteVector.fromHex("012345")), + InvoiceRequestAmount(100.msat), + InvoiceRequestPayerId(payerKey.publicKey()), + ) + val signature = signSchnorr(InvoiceRequest.signatureTag, rootHash(TlvStream(tlvs)), payerKey) + val request = InvoiceRequest(TlvStream(tlvs + Signature(signature))) + assertTrue(request.isValid()) + assertEquals(offer, request.offer) + val withDefaultChain = signInvoiceRequest(request.copy(records = TlvStream(request.records.records + InvoiceRequestChain(Block.LivenetGenesisBlock.hash))), payerKey) + assertTrue(withDefaultChain.isValid()) + assertEquals(offer, withDefaultChain.offer) + val otherChain = signInvoiceRequest(request.copy(records = TlvStream(request.records.records + InvoiceRequestChain(Block.TestnetGenesisBlock.hash))), payerKey) + assertFalse(otherChain.isValid()) + } + + @Test + fun `check that invoice request matches offer - with chains`() { + val chain1 = BlockHash(randomBytes32()) + val chain2 = BlockHash(randomBytes32()) + val offer = Offer(TlvStream(OfferChains(listOf(chain1, chain2)), OfferAmount(100.msat), OfferDescription("offer with chains"), OfferNodeId(randomKey().publicKey()))) + val payerKey = randomKey() + val request1 = InvoiceRequest(offer, 100.msat, 1, Features.empty, payerKey, chain1) + assertTrue(request1.isValid()) + assertEquals(offer, request1.offer) + val request2 = InvoiceRequest(offer, 100.msat, 1, Features.empty, payerKey, chain2) + assertTrue(request2.isValid()) + assertEquals(offer, request2.offer) + val noChain = signInvoiceRequest(request1.copy(records = TlvStream(request1.records.records.filterNot { it is InvoiceRequestChain }.toSet())), payerKey) + assertFalse(noChain.isValid()) + val otherChain = signInvoiceRequest(request1.copy(records = TlvStream(request1.records.records.map { when(it){ is InvoiceRequestChain -> InvoiceRequestChain(Block.LivenetGenesisBlock.hash) else -> it }}.toSet())), payerKey) + assertFalse(otherChain.isValid()) + } + + @Test + fun `check that invoice request matches offer - multiple items`() { + val offer = Offer(TlvStream( + OfferAmount(500.msat), + OfferDescription("offer for multiple items"), + OfferNodeId(randomKey().publicKey()), + OfferQuantityMax(10), + )) + val payerKey = randomKey() + val request = InvoiceRequest(offer, 1600.msat, 3, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + assertNotNull(request.records.get()) + assertTrue(request.isValid()) + assertEquals(offer, request.offer) + val invalidAmount = InvoiceRequest(offer, 2400.msat, 5, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + assertFalse(invalidAmount.isValid()) + val tooManyItems = InvoiceRequest(offer, 5500.msat, 11, Features.empty, payerKey, Block.LivenetGenesisBlock.hash) + assertFalse(tooManyItems.isValid()) + } + + @Test + fun `minimal invoice request`() { + val payerKey = PrivateKey.fromHex("527d410ec920b626ece685e8af9abc976a48dbf2fe698c1b35d90a1c5fa2fbca") + val tlvsWithoutSignature = setOf( + InvoiceRequestMetadata(ByteVector.fromHex("abcdef")), + OfferDescription("basic offer"), + OfferNodeId(nodeId), + InvoiceRequestPayerId(payerKey.publicKey()), + ) + val signature = signSchnorr(InvoiceRequest.signatureTag, rootHash(TlvStream(tlvsWithoutSignature)), payerKey) + val tlvs = tlvsWithoutSignature + Signature(signature) + val invoiceRequest = InvoiceRequest(TlvStream(tlvs)) + val encoded = "lnr1qqp6hn00pg9kyctnd93jqmmxvejhy93pqvxl9c6mjgkeaxa6a0vtxqteql688v0ywa8qqwx4j05cyskn8ncrjkppqfxajawru7sa7rt300hfzs2lyk2jrxduxrkx9lmzy6lxcvfhk0j7ruzqc4mtjj5fwukrqp7faqrxn664nmwykad76pu997terewcklsx47apag59wf8exly4tky7y63prr7450n28stqssmzuf48w7e6rjad2eq" + assertEquals(invoiceRequest, InvoiceRequest.decode(encoded).get()) + assertNull(invoiceRequest.offer.amount) + assertEquals("basic offer", invoiceRequest.offer.description) + assertEquals(nodeId, invoiceRequest.offer.nodeId) + assertEquals(ByteVector.fromHex("abcdef"), invoiceRequest.metadata) + assertEquals(payerKey.publicKey(), invoiceRequest.payerId) + // Removing any TLV from the minimal invoice request makes it invalid. + for (tlv in tlvs) { + val incomplete = TlvStream(tlvs.filterNot{it == tlv}.toSet()) + assertTrue(InvoiceRequest.validate(incomplete).isLeft) + val incompleteEncoded = Bech32.encodeBytes(InvoiceRequest.hrp, InvoiceRequest.tlvSerializer.write(incomplete), Bech32.Encoding.Beck32WithoutChecksum) + assertTrue(InvoiceRequest.decode(incompleteEncoded).isFailure) + } + } + + @Test + fun `compute merkle tree root`() { + data class TestCase(val tlvs: String, val count: Int, val expected: ByteVector32) + + data class GenericTlv(val data: ByteVector, override val tag: Long) : Tlv { + override fun write(out: Output) { + LightningCodecs.writeBytes(data, out) + } + } + + data class GenericTlvReader(val tag: Long) : TlvValueReader { + override fun read(input: Input): GenericTlv { + return GenericTlv(LightningCodecs.bytes(input, input.availableBytes).toByteVector(), tag) + } + } + + val genericTlvSerializer = TlvStreamSerializer( + false, (0..1000).map { i -> i.toLong() to GenericTlvReader(i.toLong()) }.toMap() + ) + + val testCases = listOf( + // Official test vectors. + TestCase("010203e8", 1, ByteVector32.fromValidHex("b013756c8fee86503a0b4abdab4cddeb1af5d344ca6fc2fa8b6c08938caa6f93")), + TestCase("010203e8 02080000010000020003", 2, ByteVector32.fromValidHex("c3774abbf4815aa54ccaa026bff6581f01f3be5fe814c620a252534f434bc0d1")), + TestCase("010203e8 02080000010000020003 03310266e4598d1d3c415f572a8488830b60f7e744ed9235eb0b1ba93283b315c0351800000000000000010000000000000002", 3, ByteVector32.fromValidHex("ab2e79b1283b0b31e0b035258de23782df6b89a38cfa7237bde69aed1a658c5d")), + TestCase("0008000000000000000006035553440801640a1741204d617468656d61746963616c205472656174697365162102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f28368661958210324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c", 6, ByteVector32.fromValidHex("608407c18ad9a94d9ea2bcdbe170b6c20c462a7833a197621c916f78cf18e624")), + // Additional test vectors. + TestCase("010100", 1, ByteVector32.fromValidHex("14ffa5e1e5d861059abff167dad6e632c45483006f7d4dc4355586062a3da30d")), + TestCase("010100 020100", 2, ByteVector32.fromValidHex("ec0584e764b71cb49ebe60ce7edbab8387e42da20b6077031bd27ff345b38ff8")), + TestCase("010100 020100 030100", 3, ByteVector32.fromValidHex("cc68aea3dc863832ef6828b3da8689cce3478c934cc50a68522477506a35feb2")), + TestCase("010100 020100 030100 040100", 4, ByteVector32.fromValidHex("b531eaa1ca71956148a6756cf8f46bdf231879e6c392019877f23e56acb7b956")), + TestCase("010100 020100 030100 040100 050100", 5, ByteVector32.fromValidHex("104e383bfdcb620cd8cefa95245332e8bd32ffd8d974fffdafe1488b1f4a1fbd")), + TestCase("010100 020100 030100 040100 050100 060100", 6, ByteVector32.fromValidHex("d96f0769702cb3440abbe683d7211fd20bd152699352f09f45d2695a89d18cdc")), + TestCase("010100 020100 030100 040100 050100 060100 070100", 7, ByteVector32.fromValidHex("30b8886e306c97dbc7b730a2e99138c1ea4fdf5c2f71e2a31e434f63f5eed228")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100", 8, ByteVector32.fromValidHex("783262efe5eeef4ec96bcee8d7cf5149ea44e0c28a78f4b1cb73d6cec9a0b378")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100", 9, ByteVector32.fromValidHex("6fd20b65a0097aff2bcc70753612a296edc27933ea335bac5df2e4c724cdb43c")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100", 10, ByteVector32.fromValidHex("9a3cf7785e9c84e03d6bc7fc04226a1cb19f158a69f16684663aa710bd90a14b")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100", 11, ByteVector32.fromValidHex("ace50a04d9dc82ce123c6ac6c2449fa607054560a9a7b8229cd2d47c01b94953")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100", 12, ByteVector32.fromValidHex("1a8e85042447a10ec312b35db34d0c8722caba4aaf6a170c4506d1fdb520aa66")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100", 13, ByteVector32.fromValidHex("8c3b8d9ba90eb9a4a34c890a7a24ba6ddc873529c5fd7c95f33a5b9ba589f54b")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100", 14, ByteVector32.fromValidHex("ed9e3694bbad2fca636576cc69af4c63ad64023bfeb788fe0f40b3533b248a6a")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100", 15, ByteVector32.fromValidHex("bab201e05786ae1eae4d685b4f815134158720ba297ea0f46a9420ffe5e94b16")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100", 16, ByteVector32.fromValidHex("44438261bb64672f374d8782e92dc9616e900378ce4bd64442753722bc2a1acb")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100", 17, ByteVector32.fromValidHex("bb6fbcd5cf426ec0b7e49d9f9ccc6c15319e01f007cce8f16fa802016718b9f7")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100", 18, ByteVector32.fromValidHex("64d8639e76af096223cad2c448d68fabf751d1c6a939bc86e1015b19188202dc")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100", 19, ByteVector32.fromValidHex("bcb88f8e06886a6d422d14bc2ed4e7fc06c0ad2adeedf630a73972c5b15538ca")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100", 20, ByteVector32.fromValidHex("9deddd5f0ab909e6a161fd4b9d44ed7384ee0a7fe8d3fbb637872767eab82f1e")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100", 21, ByteVector32.fromValidHex("4a32a2325bbd1c2b5b4915c6bec6b3e3d734d956e0c123f1fa6d70f7a8609dcd")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100", 22, ByteVector32.fromValidHex("a3ec28f0f9cb64db8d96dd7b9039fbf2240438401ea992df802d7bb70b3d02af")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100", 23, ByteVector32.fromValidHex("d025f268ec4f09baf51c4b94287e76707d9353e8cab31dc586ae47742ba0b266")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100", 24, ByteVector32.fromValidHex("cd5a2086a3919d67d0617da1e6e293f115bed8d8306498ed814c6c109ad370a4")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100", 25, ByteVector32.fromValidHex("f64113810b52f4d6a55380a3d84e59e34d26c145448121c2113a023cb63de71b")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100 1a0100", 26, ByteVector32.fromValidHex("b99d7332ea2db048093a7bc0aaa85f82ccfa9da2b734fc0a14b79c5dac5a3a1c")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100 1a0100 1b0100", 27, ByteVector32.fromValidHex("fab01a3ce6e878942dc5c9c862cb18e88202d50e6026d2266748f7eda5f9db7f")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100 1a0100 1b0100 1c0100", 28, ByteVector32.fromValidHex("2dc8b24a0e142d1ed36a144ed35ef0d4b7d0d1b51e198b2282248e45ebaf0417")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100 1a0100 1b0100 1c0100 1d0100", 29, ByteVector32.fromValidHex("3693a858cc97762d69d05b2191d3e5254c29ddb5abac5b9fe52b227fa216aa4c")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100 1a0100 1b0100 1c0100 1d0100 1e0100", 30, ByteVector32.fromValidHex("db8787d4509265e764e60b7a81cf38efb9d3a7910d67c4ae68a1232436e1cd3b")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100 1a0100 1b0100 1c0100 1d0100 1e0100 1f0100", 31, ByteVector32.fromValidHex("af49f35e5b2565cb229f342405783d330c56031f005a4a6ca01f87e5637d4614")), + TestCase("010100 020100 030100 040100 050100 060100 070100 080100 090100 0a0100 0b0100 0c0100 0d0100 0e0100 0f0100 100100 110100 120100 130100 140100 150100 160100 170100 180100 190100 1a0100 1b0100 1c0100 1d0100 1e0100 1f0100 200100", 32, ByteVector32.fromValidHex("2e9f8a8542576197650f61c882625f0f6838f962f9fa24ce809b687784a8a7de")), + ) + testCases.forEach { + (tlvStream, tlvCount, expectedRoot) -> + val tlvs = genericTlvSerializer.read(ByteVector.fromHex(tlvStream).toByteArray()) + assertEquals(tlvCount, tlvs.records.size) + val root = rootHash(tlvs) + assertEquals(expectedRoot, root) + } + } + + @Test + fun `compact blinded route`() { + data class TestCase(val encoded: ByteVector, val decoded: BlindedPath) + + val testCases = listOf( + TestCase(ByteVector.fromHex("00 00000000000004d2 0379b470d00b78ded936f8972a0f3ecda2bb6e6df40dcd581dbaeb3742b30008ff 01 02fba71b72623187dd24670110eec870e28b848f255ba2edc0486d3a8e89ec44b7 0002 1dea"), + BlindedPath(RouteBlinding.BlindedRoute(EncodedNodeId.ShortChannelIdDir(isNode1 = true, ShortChannelId(1234)), PublicKey.fromHex("0379b470d00b78ded936f8972a0f3ecda2bb6e6df40dcd581dbaeb3742b30008ff"), listOf(RouteBlinding.BlindedNode(PublicKey.fromHex("02fba71b72623187dd24670110eec870e28b848f255ba2edc0486d3a8e89ec44b7"), ByteVector.fromHex("1dea")))))), + TestCase(ByteVector.fromHex("01 000000000000ddd5 0353a081bb02d6e361be3df3e92b41b788ca65667f6ea0c01e2bfa03664460ef86 01 03bce3f0cdb4172caac82ec8a9251eb35df1201bdcb977c5a03f3624ec4156a65f 0003 c0ffee"), + BlindedPath(RouteBlinding.BlindedRoute(EncodedNodeId.ShortChannelIdDir(isNode1 = false, ShortChannelId(56789)), PublicKey.fromHex("0353a081bb02d6e361be3df3e92b41b788ca65667f6ea0c01e2bfa03664460ef86"), listOf(RouteBlinding.BlindedNode(PublicKey.fromHex("03bce3f0cdb4172caac82ec8a9251eb35df1201bdcb977c5a03f3624ec4156a65f"), ByteVector.fromHex("c0ffee")))))), + TestCase(ByteVector.fromHex("022d3b15cea00ee4a8e710b082bef18f0f3409cc4e7aff41c26eb0a4d3ab20dd73 0379a3b6e4bceb7519d09db776994b1f82cf6a9fa4d3ec2e52314c5938f2f9f966 01 02b446aaa523df82a992ab468e5298eabb6168e2c466455c210d8c97dbb8981328 0002 cafe"), + BlindedPath(RouteBlinding.BlindedRoute(EncodedNodeId.Plain(PublicKey.fromHex("022d3b15cea00ee4a8e710b082bef18f0f3409cc4e7aff41c26eb0a4d3ab20dd73")), PublicKey.fromHex("0379a3b6e4bceb7519d09db776994b1f82cf6a9fa4d3ec2e52314c5938f2f9f966"), listOf(RouteBlinding.BlindedNode(PublicKey.fromHex("02b446aaa523df82a992ab468e5298eabb6168e2c466455c210d8c97dbb8981328"), ByteVector.fromHex("cafe")))))), + TestCase(ByteVector.fromHex("03ba3c458e3299eb19d2e07ae86453f4290bcdf8689707f0862f35194397c45922 028aa5d1a10463d598a0a0ab7296af21619049f94fe03ef664a87561009e58c3dd 01 02988d7381d0434cfebbe521031505fb9987ae6cefd0bab0e5927852eb96bb6cc2 0003 ec1a13"), + BlindedPath(RouteBlinding.BlindedRoute(EncodedNodeId.Plain(PublicKey.fromHex("03ba3c458e3299eb19d2e07ae86453f4290bcdf8689707f0862f35194397c45922")), PublicKey.fromHex("028aa5d1a10463d598a0a0ab7296af21619049f94fe03ef664a87561009e58c3dd"), listOf(RouteBlinding.BlindedNode(PublicKey.fromHex("02988d7381d0434cfebbe521031505fb9987ae6cefd0bab0e5927852eb96bb6cc2"), ByteVector.fromHex("ec1a13")))))), + ) + + testCases.forEach { + (encoded, decoded) -> + val out = ByteArrayOutput() + writePath(decoded, out) + assertEquals(encoded, out.toByteArray().toByteVector()) + assertEquals(decoded, readPath(ByteArrayInput(encoded.toByteArray()))) + } + } +} \ No newline at end of file