diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 4dbfa52237..4e779de128 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2605,40 +2605,6 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // forwarded. availableBandwidth := l.Bandwidth() - auxBandwidth, externalErr := fn.MapOptionZ( - l.cfg.AuxTrafficShaper, - func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { - var htlcBlob fn.Option[tlv.Blob] - blob, err := customRecords.Serialize() - if err != nil { - return fn.Err[OptionalBandwidth]( - fmt.Errorf("unable to serialize "+ - "custom records: %w", err)) - } - - if len(blob) > 0 { - htlcBlob = fn.Some(blob) - } - - return l.AuxBandwidth(amt, originalScid, htlcBlob, ts) - }, - ).Unpack() - if externalErr != nil { - l.log.Errorf("Unable to determine aux bandwidth: %v", - externalErr) - - return NewLinkError(&lnwire.FailTemporaryNodeFailure{}) - } - - if auxBandwidth.IsHandled && auxBandwidth.Bandwidth.IsSome() { - auxBandwidth.Bandwidth.WhenSome( - func(bandwidth lnwire.MilliSatoshi) { - availableBandwidth = bandwidth - }, - ) - } - - // Check to see if there is enough balance in this channel. if amt > availableBandwidth { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, availableBandwidth) diff --git a/lnwallet/aux_test_utils.go b/lnwallet/aux_test_utils.go new file mode 100644 index 0000000000..eb140644ab --- /dev/null +++ b/lnwallet/aux_test_utils.go @@ -0,0 +1,39 @@ +package lnwallet + +import ( + "github.com/lightningnetwork/lnd/lnwire" +) + +// NewTestAuxHtlcDescriptor creates an AuxHtlcDescriptor for testing purposes. +// This function allows tests to create descriptors with specific commit heights +// and entry types, which are normally unexported fields. +func NewTestAuxHtlcDescriptor( + chanID lnwire.ChannelID, + rHash PaymentHash, + timeout uint32, + amount lnwire.MilliSatoshi, + htlcIndex uint64, + parentIndex uint64, + entryType uint8, + customRecords lnwire.CustomRecords, + addHeightLocal uint64, + addHeightRemote uint64, + removeHeightLocal uint64, + removeHeightRemote uint64, +) AuxHtlcDescriptor { + + return AuxHtlcDescriptor{ + ChanID: chanID, + RHash: rHash, + Timeout: timeout, + Amount: amount, + HtlcIndex: htlcIndex, + ParentIndex: parentIndex, + EntryType: updateType(entryType), + CustomRecords: customRecords, + addCommitHeightLocal: addHeightLocal, + addCommitHeightRemote: addHeightRemote, + removeCommitHeightLocal: removeHeightLocal, + removeCommitHeightRemote: removeHeightRemote, + } +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 484a019da5..4e072a5b5b 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -833,6 +833,14 @@ type LightningChannel struct { // is created. type ChannelOpt func(*channelOpts) +// AuxHtlcValidator is a function that validates whether an HTLC can be added +// to a custom channel. It is called during HTLC validation with the current +// channel state and HTLC details. This allows external components (like the +// traffic shaper) to perform final validation checks against the most +// up-to-date channel state before the HTLC is committed. +type AuxHtlcValidator func(amount, linkBandwidth lnwire.MilliSatoshi, + customRecords lnwire.CustomRecords, view AuxHtlcView) error + // channelOpts is the set of options used to create a new channel. type channelOpts struct { localNonce *musig2.Nonces @@ -842,6 +850,10 @@ type channelOpts struct { auxSigner fn.Option[AuxSigner] auxResolver fn.Option[AuxContractResolver] + // auxHtlcValidator is an optional validator that performs custom + // validation on HTLCs before they are added to the channel state. + auxHtlcValidator fn.Option[AuxHtlcValidator] + skipNonceInit bool } @@ -894,6 +906,15 @@ func WithAuxResolver(resolver AuxContractResolver) ChannelOpt { } } +// WithAuxHtlcValidator is used to specify a custom HTLC validator for the +// channel. This validator will be called during HTLC addition to perform +// final validation checks against the most up-to-date channel state. +func WithAuxHtlcValidator(validator AuxHtlcValidator) ChannelOpt { + return func(o *channelOpts) { + o.auxHtlcValidator = fn.Some(validator) + } +} + // defaultChannelOpts returns the set of default options for a new channel. func defaultChannelOpts() *channelOpts { return &channelOpts{} @@ -2738,9 +2759,15 @@ func (lc *LightningChannel) FetchLatestAuxHTLCView() AuxHtlcView { lc.RLock() defer lc.RUnlock() - return newAuxHtlcView(lc.fetchHTLCView( - lc.updateLogs.Remote.logIndex, lc.updateLogs.Local.logIndex, - )) + nextHeight := lc.commitChains.Local.tip().height + 1 + remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote + view := lc.fetchHTLCView( + remoteACKedIndex, lc.updateLogs.Local.logIndex, + ) + + view.NextHeight = nextHeight + + return newAuxHtlcView(view) } // fetchHTLCView returns all the candidate HTLC updates which should be @@ -6065,6 +6092,52 @@ func (lc *LightningChannel) addHTLC(htlc *lnwire.UpdateAddHTLC, return 0, err } + // If an auxiliary HTLC validator is configured, call it now to perform + // custom validation checks against the current channel state. This is + // the final validation point before the HTLC is added to the update + // log, ensuring that the validator sees the most up-to-date state + // including all previously validated HTLCs in this batch. + // + // NOTE: This is called after the standard commitment sanity checks to + // ensure we only perform (potentially) expensive custom validation on + // HTLCs that have already passed the basic Lightning protocol + // constraints. + err := fn.MapOptionZ( + lc.opts.auxHtlcValidator, + func(validator AuxHtlcValidator) error { + // Fetch the current HTLC view which includes all + // pending HTLCs that haven't been committed yet. This + // provides the validator with the most accurate state. + commitChain := lc.commitChains.Local + remoteIndex := commitChain.tail().messageIndices.Remote + view := lc.fetchHTLCView( + remoteIndex, + lc.updateLogs.Local.logIndex, + ) + + nextHeight := lc.commitChains.Local.tip().height + 1 + view.NextHeight = nextHeight + + lc.log.Infof("Setting view nextheight=%v", nextHeight) + + auxView := newAuxHtlcView(view) + + // Get the current available balance for the link + // bandwidth check. This is needed for the balance + // validation in the traffic shaper. We use NoBuffer + // since this is the final check before adding the HTLC. + linkBandwidth, _ := lc.availableBalance(NoBuffer) + + return validator( + pd.Amount, linkBandwidth, pd.CustomRecords, + auxView, + ) + }, + ) + if err != nil { + return 0, fmt.Errorf("aux HTLC validation failed: %w", err) + } + lc.updateLogs.Local.appendHtlc(pd) return pd.HtlcIndex, nil @@ -6216,6 +6289,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, // remote commitments. func (lc *LightningChannel) validateAddHtlc(pd *paymentDescriptor, buffer BufferType) error { + // Make sure adding this HTLC won't violate any of the constraints we // must keep on the commitment transactions. remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote diff --git a/peer/brontide.go b/peer/brontide.go index 560e8d121f..77961bd5ee 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -52,6 +52,7 @@ import ( "github.com/lightningnetwork/lnd/pool" "github.com/lightningnetwork/lnd/protofsm" "github.com/lightningnetwork/lnd/queue" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/tlv" @@ -1140,6 +1141,16 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( }, ) + p.cfg.AuxTrafficShaper.WhenSome( + func(ts htlcswitch.AuxTrafficShaper) { + val := p.createHtlcValidator(dbChan, ts) + chanOpts = append( + chanOpts, + lnwallet.WithAuxHtlcValidator(val), + ) + }, + ) + lnChan, err := lnwallet.NewLightningChannel( p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts..., ) @@ -5229,6 +5240,15 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s)) }) + p.cfg.AuxTrafficShaper.WhenSome( + func(ts htlcswitch.AuxTrafficShaper) { + val := p.createHtlcValidator(c.OpenChannel, ts) + chanOpts = append( + chanOpts, lnwallet.WithAuxHtlcValidator(val), + ) + }, + ) + // If not already active, we'll add this channel to the set of active // channels, so we can look it up later easily according to its channel // ID. @@ -5435,6 +5455,66 @@ func (p *Brontide) scaleTimeout(timeout time.Duration) time.Duration { return timeout } +// createHtlcValidator creates an HTLC validator function that performs final +// aux balance validation before HTLCs are added to the channel state. This +// validator calls into the traffic shaper's PaymentBandwidth method to check +// external balance against the most up-to-date channel state, preventing race +// conditions where multiple HTLCs could be approved based on stale bandwidth. +func (p *Brontide) createHtlcValidator(dbChan *channeldb.OpenChannel, + ts htlcswitch.AuxTrafficShaper) lnwallet.AuxHtlcValidator { + + return func(amount, linkBandwidth lnwire.MilliSatoshi, + customRecords lnwire.CustomRecords, + view lnwallet.AuxHtlcView) error { + + // Get the short channel ID for logging. + scid := dbChan.ShortChannelID + + // Extract the HTLC custom records to pass to the traffic + // shaper. + var htlcBlob fn.Option[tlv.Blob] + if len(customRecords) > 0 { + blob, err := customRecords.Serialize() + if err != nil { + return fmt.Errorf("unable to serialize "+ + "custom records: %w", err) + } + htlcBlob = fn.Some(blob) + } + + // Get the funding and commitment blobs for this channel. + fundingBlob := dbChan.CustomBlob + commitmentBlob := dbChan.LocalCommitment.CustomBlob + + peer := route.NewVertex(p.IdentityKey()) + + // Call the traffic shaper's PaymentBandwidth method with the + // current state. This performs the same bandwidth checks as + // during pathfinding/forwarding, but against the absolute + // latest channel state. + // + // The linkBandwidth is provided by the channel and represents + // the current available balance, which is used by the traffic + // shaper to ensure we don't dip below channel reserves. + bandwidth, err := ts.PaymentBandwidth( + fundingBlob, htlcBlob, commitmentBlob, + linkBandwidth, amount, view, peer, + ) + if err != nil { + return fmt.Errorf("traffic shaper bandwidth check "+ + "failed: %w", err) + } + + if amount > bandwidth { + return fmt.Errorf("insufficient aux bandwidth: "+ + "need %v, have %v (scid=%v)", amount, + bandwidth, scid) + } + + return nil + } +} + // CoopCloseUpdates is a struct used to communicate updates for an active close // to the caller. type CoopCloseUpdates struct {