Skip to content

Commit

Permalink
Merge pull request #938 from ellemouton/sql2Accounts2
Browse files Browse the repository at this point in the history
[sql-2] accounts: start replacing calls to UpdateAccount
  • Loading branch information
ellemouton authored Jan 17, 2025
2 parents 211865a + 34d1f67 commit edaf59d
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 29 deletions.
16 changes: 16 additions & 0 deletions accounts/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"time"

"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
Expand Down Expand Up @@ -219,6 +220,21 @@ type Store interface {
// Accounts retrieves all accounts from the store and un-marshals them.
Accounts(ctx context.Context) ([]*OffChainBalanceAccount, error)

// UpdateAccountBalanceAndExpiry updates the balance and/or expiry of an
// account.
UpdateAccountBalanceAndExpiry(ctx context.Context, id AccountID,
newBalance fn.Option[lnwire.MilliSatoshi],
newExpiry fn.Option[time.Time]) error

// AddAccountInvoice adds an invoice hash to an account.
AddAccountInvoice(ctx context.Context, id AccountID,
hash lntypes.Hash) error

// IncreaseAccountBalance increases the balance of the account with the
// given ID by the given amount.
IncreaseAccountBalance(ctx context.Context, id AccountID,
amount lnwire.MilliSatoshi) error

// RemoveAccount finds an account by its ID and removes it from the¨
// store.
RemoveAccount(ctx context.Context, id AccountID) error
Expand Down
3 changes: 2 additions & 1 deletion accounts/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ func (s *RPCServer) UpdateAccount(ctx context.Context,

// Ask the service to update the account.
account, err := s.service.UpdateAccount(
ctx, accountID, req.AccountBalance, req.ExpirationDate,
ctx, accountID, btcutil.Amount(req.AccountBalance),
req.ExpirationDate,
)
if err != nil {
return nil, err
Expand Down
48 changes: 20 additions & 28 deletions accounts/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
"sync"
"time"

"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/taproot-assets/fn"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
invpkg "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes"
Expand Down Expand Up @@ -299,7 +300,7 @@ func (s *InterceptorService) NewAccount(ctx context.Context,
// UpdateAccount writes an account to the database, overwriting the existing one
// if it exists.
func (s *InterceptorService) UpdateAccount(ctx context.Context,
accountID AccountID, accountBalance,
accountID AccountID, accountBalance btcutil.Amount,
expirationDate int64) (*OffChainBalanceAccount, error) {

s.Lock()
Expand All @@ -313,36 +314,35 @@ func (s *InterceptorService) UpdateAccount(ctx context.Context,
return nil, ErrAccountServiceDisabled
}

account, err := s.store.Account(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("error fetching account: %w", err)
}

// If the expiration date was set, parse it as a unix time stamp. A
// value of -1 signals "don't update the expiration date".
var expiry fn.Option[time.Time]
if expirationDate > 0 {
account.ExpirationDate = time.Unix(expirationDate, 0)
expiry = fn.Some(time.Unix(expirationDate, 0))
} else if expirationDate == 0 {
// Setting the expiration to 0 means don't expire in which case
// we use a zero time (zero unix time would still be 1970, so
// that doesn't work for us).
account.ExpirationDate = time.Time{}
expiry = fn.Some(time.Time{})
}

// If the new account balance was set, parse it as millisatoshis. A
// value of -1 signals "don't update the balance".
var balance fn.Option[lnwire.MilliSatoshi]
if accountBalance >= 0 {
// Convert from satoshis to millisatoshis for storage.
account.CurrentBalance = int64(accountBalance) * 1000
balance = fn.Some(lnwire.MilliSatoshi(accountBalance) * 1000)
}

// Create the actual account in the macaroon account store.
err = s.store.UpdateAccount(ctx, account)
err := s.store.UpdateAccountBalanceAndExpiry(
ctx, accountID, balance, expiry,
)
if err != nil {
return nil, fmt.Errorf("unable to update account: %w", err)
}

return account, nil
return s.store.Account(ctx, accountID)
}

// Account retrieves an account from the bolt DB and un-marshals it. If the
Expand Down Expand Up @@ -439,15 +439,15 @@ func (s *InterceptorService) AssociateInvoice(ctx context.Context, id AccountID,
s.Lock()
defer s.Unlock()

account, err := s.store.Account(ctx, id)
err := s.store.AddAccountInvoice(ctx, id, hash)
if err != nil {
return err
return fmt.Errorf("error adding invoice to account: %w", err)
}

account.Invoices[hash] = struct{}{}
// If the above was successful, then we update our in-memory map.
s.invoiceToAccount[hash] = id

return s.store.UpdateAccount(ctx, account)
return nil
}

// PaymentErrored removes a pending payment from the account's registered
Expand Down Expand Up @@ -599,21 +599,13 @@ func (s *InterceptorService) invoiceUpdate(ctx context.Context,
return nil
}

account, err := s.store.Account(ctx, acctID)
if err != nil {
return s.disableAndErrorfUnsafe(
"error fetching account: %w", err,
)
}

// If we get here, the current account has the invoice associated with
// it that was just paid. Credit the amount to the account and update it
// in the DB.
account.CurrentBalance += int64(invoice.AmountPaid)
if err := s.store.UpdateAccount(ctx, account); err != nil {
return s.disableAndErrorfUnsafe(
"error updating account: %w", err,
)
err := s.store.IncreaseAccountBalance(ctx, acctID, invoice.AmountPaid)
if err != nil {
return s.disableAndErrorfUnsafe("error increasing account "+
"balance account: %w", err)
}

// We've now fully processed the invoice and don't need to keep it
Expand Down
102 changes: 102 additions & 0 deletions accounts/store_kvdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"os"
"time"

"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"go.etcd.io/bbolt"
)
Expand Down Expand Up @@ -194,6 +197,92 @@ func (s *BoltStore) UpdateAccount(_ context.Context,
}, func() {})
}

// UpdateAccountBalanceAndExpiry updates the balance and/or expiry of an
// account.
//
// NOTE: This is part of the Store interface.
func (s *BoltStore) UpdateAccountBalanceAndExpiry(_ context.Context,
id AccountID, newBalance fn.Option[lnwire.MilliSatoshi],
newExpiry fn.Option[time.Time]) error {

update := func(account *OffChainBalanceAccount) error {
newBalance.WhenSome(func(balance lnwire.MilliSatoshi) {
account.CurrentBalance = int64(balance)
})
newExpiry.WhenSome(func(expiry time.Time) {
account.ExpirationDate = expiry
})

return nil
}

return s.updateAccount(id, update)
}

// AddAccountInvoice adds an invoice hash to the account with the given ID.
//
// NOTE: This is part of the Store interface.
func (s *BoltStore) AddAccountInvoice(_ context.Context, id AccountID,
hash lntypes.Hash) error {

update := func(account *OffChainBalanceAccount) error {
account.Invoices[hash] = struct{}{}

return nil
}

return s.updateAccount(id, update)
}

// IncreaseAccountBalance increases the balance of the account with the given ID
// by the given amount.
//
// NOTE: This is part of the Store interface.
func (s *BoltStore) IncreaseAccountBalance(_ context.Context, id AccountID,
amount lnwire.MilliSatoshi) error {

update := func(account *OffChainBalanceAccount) error {
if amount > math.MaxInt64 {
return fmt.Errorf("amount %d exceeds the maximum of %d",
amount, math.MaxInt64)
}

account.CurrentBalance += int64(amount)

return nil
}

return s.updateAccount(id, update)
}

func (s *BoltStore) updateAccount(id AccountID,
updateFn func(*OffChainBalanceAccount) error) error {

return s.db.Update(func(tx kvdb.RwTx) error {
bucket := tx.ReadWriteBucket(accountBucketName)
if bucket == nil {
return ErrAccountBucketNotFound
}

account, err := getAccount(bucket, id)
if err != nil {
return fmt.Errorf("error fetching account, %w", err)
}

err = updateFn(account)
if err != nil {
return fmt.Errorf("error updating account, %w", err)
}

err = storeAccount(bucket, account)
if err != nil {
return fmt.Errorf("error storing account, %w", err)
}

return nil
}, func() {})
}

// storeAccount serializes and writes the given account to the given account
// bucket.
func storeAccount(accountBucket kvdb.RwBucket,
Expand All @@ -209,6 +298,19 @@ func storeAccount(accountBucket kvdb.RwBucket,
return accountBucket.Put(account.ID[:], accountBinary)
}

// getAccount retrieves an account from the given account bucket and
// deserializes it.
func getAccount(accountBucket kvdb.RwBucket, id AccountID) (
*OffChainBalanceAccount, error) {

accountBinary := accountBucket.Get(id[:])
if len(accountBinary) == 0 {
return nil, ErrAccNotFound
}

return deserializeAccount(accountBinary)
}

// uniqueRandomAccountID generates a new random ID and makes sure it does not
// yet exist in the DB.
func uniqueRandomAccountID(accountBucket kvdb.RBucket) (AccountID, error) {
Expand Down
Loading

0 comments on commit edaf59d

Please sign in to comment.