From 43b2c520decd1e3e2e5c448f11a01c708590003d Mon Sep 17 00:00:00 2001 From: Jesse de Wit Date: Fri, 5 Jan 2024 11:19:18 +0100 Subject: [PATCH] pass chan updates on failure --- common/chan_update.go | 38 +++++++++++++++++++++++ common/intercept_handler.go | 27 ++++++++++++++++ interceptor/intercept_handler.go | 25 ++++++++++----- itest/cltv_test.go | 2 +- lsps2/intercept_handler.go | 53 +++++++++++++++++++++++++------- lsps2/intercept_test.go | 6 ++-- 6 files changed, 129 insertions(+), 22 deletions(-) create mode 100644 common/chan_update.go diff --git a/common/chan_update.go b/common/chan_update.go new file mode 100644 index 00000000..19ebae6c --- /dev/null +++ b/common/chan_update.go @@ -0,0 +1,38 @@ +package common + +import ( + "bytes" + "time" + + "github.com/breez/lspd/lightning" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/lnwire" +) + +func ConstructChanUpdate( + chainhash chainhash.Hash, + node []byte, + destination []byte, + scid lightning.ShortChannelID, + timeLockDelta uint16, + htlcMinimumMsat, + htlcMaximumMsat uint64, +) lnwire.ChannelUpdate { + channelFlags := lnwire.ChanUpdateChanFlags(0) + if bytes.Compare(node, destination) > 0 { + channelFlags = 1 + } + + return lnwire.ChannelUpdate{ + ChainHash: chainhash, + ShortChannelID: scid.ToLnwire(), + Timestamp: uint32(time.Now().Unix()), + TimeLockDelta: timeLockDelta, + HtlcMinimumMsat: lnwire.MilliSatoshi(htlcMinimumMsat), + HtlcMaximumMsat: lnwire.MilliSatoshi(htlcMaximumMsat), + BaseFee: 0, + FeeRate: 0, + MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, + ChannelFlags: channelFlags, + } +} diff --git a/common/intercept_handler.go b/common/intercept_handler.go index 9a2eea74..d5dbc899 100644 --- a/common/intercept_handler.go +++ b/common/intercept_handler.go @@ -1,10 +1,13 @@ package common import ( + "bytes" "fmt" + "log" "github.com/breez/lspd/lightning" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwire" ) type InterceptAction int @@ -27,6 +30,30 @@ var ( FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS InterceptFailureCode = []byte{0x40, 0x0F} ) +func FailureTemporaryChannelFailure(update *lnwire.ChannelUpdate) []byte { + var buf bytes.Buffer + msg := lnwire.NewTemporaryChannelFailure(update) + err := lnwire.EncodeFailureMessage(&buf, msg, 0) + if err != nil { + log.Printf("Failed to encode failure message for temporary channel failure: %v", err) + return FAILURE_TEMPORARY_CHANNEL_FAILURE + } + + return buf.Bytes() +} + +func FailureIncorrectCltvExpiry(cltvExpiry uint32, update lnwire.ChannelUpdate) []byte { + var buf bytes.Buffer + msg := lnwire.NewIncorrectCltvExpiry(cltvExpiry, update) + err := lnwire.EncodeFailureMessage(&buf, msg, 0) + if err != nil { + log.Printf("Failed to encode failure message for incorrect cltv expiry: %v", err) + return FAILURE_INCORRECT_CLTV_EXPIRY + } + + return buf.Bytes() +} + type InterceptRequest struct { // Identifier that uniquely identifies this htlc. // For cln, that's hash of the next onion or the shared secret. diff --git a/interceptor/intercept_handler.go b/interceptor/intercept_handler.go index 0b03db82..26ddb62e 100644 --- a/interceptor/intercept_handler.go +++ b/interceptor/intercept_handler.go @@ -111,7 +111,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("IsConnected(%x) error: %v", nextHop, err) return &common.InterceptResult{ Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureMessage: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + FailureMessage: common.FailureTemporaryChannelFailure(nil), }, nil } @@ -160,6 +160,17 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes }, nil } + // In case we fail with an error, this is the used channel update. + chanUpdate := common.ConstructChanUpdate( + i.node.ChainHash, + i.node.NodeId, + destination, + req.Scid, + uint16(i.node.NodeConfig.TimeLockDelta), + i.node.NodeConfig.MinHtlcMsat, + req.IncomingAmountMsat, + ) + // The first htlc of a MPP will open the channel. if channelPoint == nil { // TODO: When opening_fee_params is enforced, turn this check in a temporary channel failure. @@ -179,7 +190,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("paymentHash: %s, outgoingExpiry: %v, incomingExpiry: %v, i.node.NodeConfig.TimeLockDelta: %v", reqPaymentHashStr, req.OutgoingExpiry, req.IncomingExpiry, i.node.NodeConfig.TimeLockDelta) return common.InterceptResult{ Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureMessage: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + FailureMessage: common.FailureIncorrectCltvExpiry(req.IncomingExpiry, chanUpdate), }, nil } @@ -188,7 +199,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("paymentHash: %s, time.Parse(%s, %s) failed. Failing channel open: %v", reqPaymentHashStr, lsps0.TIME_FORMAT, params.ValidUntil, err) return common.InterceptResult{ Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureMessage: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } @@ -199,7 +210,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("Intercepted expired payment registration. Failing payment. payment hash: %s, valid until: %s", reqPaymentHashStr, params.ValidUntil) return common.InterceptResult{ Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureMessage: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } @@ -211,7 +222,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("paymentHash: %s, openChannel(%x, %v) err: %v", reqPaymentHashStr, destination, incomingAmountMsat, err) return common.InterceptResult{ Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureMessage: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } } @@ -238,7 +249,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("paymentHash: %s, insertChannel error: %v", reqPaymentHashStr, err) return common.InterceptResult{ Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureMessage: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil } @@ -271,7 +282,7 @@ func (i *Interceptor) Intercept(req common.InterceptRequest) common.InterceptRes log.Printf("paymentHash: %s, Error: Channel failed to open... timed out. ", reqPaymentHashStr) return common.InterceptResult{ Action: common.INTERCEPT_FAIL_HTLC_WITH_CODE, - FailureMessage: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + FailureMessage: common.FailureTemporaryChannelFailure(&chanUpdate), }, nil }) diff --git a/itest/cltv_test.go b/itest/cltv_test.go index bbe0f5c9..780e29f8 100644 --- a/itest/cltv_test.go +++ b/itest/cltv_test.go @@ -54,5 +54,5 @@ func testInvalidCltv(p *testParams) { // Decrement the delay in the first hop, so the cltv delta will become 143 (too little) route.Hops[0].Delay-- _, err := alice.PayViaRoute(outerAmountMsat, outerInvoice.paymentHash, outerInvoice.paymentSecret, route) - assert.Contains(p.t, err.Error(), "WIRE_TEMPORARY_CHANNEL_FAILURE") + assert.Contains(p.t, err.Error(), "WIRE_INCORRECT_CLTV_EXPIRY") } diff --git a/lsps2/intercept_handler.go b/lsps2/intercept_handler.go index 92ff53c8..c28d23ec 100644 --- a/lsps2/intercept_handler.go +++ b/lsps2/intercept_handler.go @@ -114,7 +114,7 @@ type paymentChanOpenedEvent struct { type paymentFailureEvent struct { paymentId string - code common.InterceptFailureCode + message common.InterceptFailureCode } func (i *Interceptor) Start(ctx context.Context) { @@ -135,7 +135,7 @@ func (i *Interceptor) Start(ctx context.Context) { case paymentId := <-i.paymentReady: i.handlePaymentReady(paymentId) case ev := <-i.paymentFailure: - i.handlePaymentFailure(ev.paymentId, ev.code) + i.handlePaymentFailure(ev.paymentId, ev.message) case ev := <-i.paymentChanOpened: i.handlePaymentChanOpened(ev) } @@ -189,7 +189,7 @@ func (i *Interceptor) handleNewPart(part *partState) { // a goroutine. i.paymentFailure <- &paymentFailureEvent{ paymentId: paymentId, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(nil), } case <-payment.timeoutChan: // Stop listening for timeouts when the payment is ready. @@ -284,7 +284,17 @@ func (i *Interceptor) processPart(payment *paymentState, part *partState) { // Make sure the cltv delta is enough (actual cltv delta + 2). if int64(part.req.IncomingExpiry)-int64(part.req.OutgoingExpiry) < int64(i.config.TimeLockDelta)+2 { - i.failPart(payment, part, common.FAILURE_INCORRECT_CLTV_EXPIRY) + peerid, _ := hex.DecodeString(payment.registration.PeerId) + chanUpdate := common.ConstructChanUpdate( + i.config.ChainHash, + i.config.NodeId, + peerid, + payment.fakeScid, + uint16(i.config.TimeLockDelta), + i.config.HtlcMinimumMsat, + payment.paymentSizeMsat, + ) + i.failPart(payment, part, common.FailureIncorrectCltvExpiry(part.req.IncomingExpiry, chanUpdate)) return } @@ -376,6 +386,16 @@ func (i *Interceptor) handlePaymentReady(paymentId string) { // a goroutine. func (i *Interceptor) ensureChannelOpen(payment *paymentState) { destination, _ := hex.DecodeString(payment.registration.PeerId) + peerid, _ := hex.DecodeString(payment.registration.PeerId) + chanUpdate := common.ConstructChanUpdate( + i.config.ChainHash, + i.config.NodeId, + peerid, + payment.fakeScid, + uint16(i.config.TimeLockDelta), + i.config.HtlcMinimumMsat, + payment.paymentSizeMsat, + ) if payment.registration.ChannelPoint == nil { @@ -392,7 +412,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { ) i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_UNKNOWN_NEXT_PEER, + message: common.FAILURE_UNKNOWN_NEXT_PEER, } return } @@ -412,7 +432,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { ) i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_UNKNOWN_NEXT_PEER, + message: common.FAILURE_UNKNOWN_NEXT_PEER, } return } @@ -468,7 +488,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { code := common.FAILURE_UNKNOWN_NEXT_PEER if strings.Contains(err.Error(), "not enough funds") { - code = common.FAILURE_TEMPORARY_CHANNEL_FAILURE + code = common.FailureTemporaryChannelFailure(&chanUpdate) } // TODO: Verify that a client disconnect before receiving @@ -479,7 +499,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { // temporary_channel_failure should be returned. i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: code, + message: code, } return } @@ -502,7 +522,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { ) i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(&chanUpdate), } return } @@ -524,7 +544,7 @@ func (i *Interceptor) ensureChannelOpen(payment *paymentState) { case <-time.After(time.Until(deadline)): i.paymentFailure <- &paymentFailureEvent{ paymentId: payment.id, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(&chanUpdate), } return } @@ -603,11 +623,22 @@ func (i *Interceptor) handlePaymentChanOpened(event *paymentChanOpenedEvent) { event.paymentId, feeRemainingMsat, ) + + peerid, _ := hex.DecodeString(payment.registration.PeerId) + chanUpdate := common.ConstructChanUpdate( + i.config.ChainHash, + i.config.NodeId, + peerid, + payment.fakeScid, + uint16(i.config.TimeLockDelta), + i.config.HtlcMinimumMsat, + payment.paymentSizeMsat, + ) // TODO: Verify temporary_channel_failure is the way to go here, maybe // unknown_next_peer is more appropriate. i.paymentFailure <- &paymentFailureEvent{ paymentId: event.paymentId, - code: common.FAILURE_TEMPORARY_CHANNEL_FAILURE, + message: common.FailureTemporaryChannelFailure(&chanUpdate), } return } diff --git a/lsps2/intercept_test.go b/lsps2/intercept_test.go index d55b2e67..f5a800cf 100644 --- a/lsps2/intercept_test.go +++ b/lsps2/intercept_test.go @@ -322,7 +322,7 @@ func Test_NoMpp_CltvDeltaBelowMinimum(t *testing.T) { res := i.Intercept(createPart(&part{cltvDelta: 145})) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.ElementsMatch(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureMessage) + assert.ElementsMatch(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureMessage[:2]) assertEmpty(t, i) } @@ -408,7 +408,7 @@ func Test_Mpp_SinglePart_AmtTooSmall(t *testing.T) { res := i.Intercept(createPart(&part{amt: defaultPaymentSizeMsat - 1})) end := time.Now() assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.ElementsMatch(t, common.FAILURE_TEMPORARY_CHANNEL_FAILURE, res.FailureMessage) + assert.ElementsMatch(t, common.FAILURE_TEMPORARY_CHANNEL_FAILURE, res.FailureMessage[:2]) assert.GreaterOrEqual(t, end.Sub(start).Milliseconds(), config.MppTimeout.Milliseconds()) assertEmpty(t, i) } @@ -541,7 +541,7 @@ func Test_Mpp_CltvDeltaBelowMinimum(t *testing.T) { res := i.Intercept(createPart(&part{cltvDelta: 145})) assert.Equal(t, common.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) - assert.ElementsMatch(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureMessage) + assert.ElementsMatch(t, common.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureMessage[:2]) assertEmpty(t, i) }