From 62e1a7eb618395919d8c025746b14cb18d528250 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 27 Dec 2024 17:03:49 +0200 Subject: [PATCH 1/8] accounts: rename store file We rename the `store.go` file to `store_kvdb.go` to indicate that this file contains the kvdb implementation of the accounts DB. This is in preparation for adding a sql-backed implementation later on. We do this early on in the PR so that any changes that need to be made during the review process can be easily addressed with fix-up commits that edit the newly named file. --- accounts/{store.go => store_kvdb.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename accounts/{store.go => store_kvdb.go} (100%) diff --git a/accounts/store.go b/accounts/store_kvdb.go similarity index 100% rename from accounts/store.go rename to accounts/store_kvdb.go From 86fc2ccf79e8e81e359d9d10c969367ab7fecf0a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 27 Dec 2024 17:24:01 +0200 Subject: [PATCH 2/8] accounts: let interface methods take a context Update the accounts `Store` and `Service` interfaces take a context. This is in preparation for when the backend DB of the accounts service is a SQL store which will have methods that take a context. --- accounts/checkers.go | 20 ++++--- accounts/checkers_test.go | 34 +++++++----- accounts/interceptor.go | 2 +- accounts/interface.go | 38 ++++++++----- accounts/rpcserver.go | 32 ++++++----- accounts/service.go | 112 +++++++++++++++++++++----------------- accounts/service_test.go | 66 +++++++++++----------- accounts/store_kvdb.go | 25 ++++++--- accounts/store_test.go | 29 +++++----- 9 files changed, 202 insertions(+), 156 deletions(-) diff --git a/accounts/checkers.go b/accounts/checkers.go index 00a41d393..0b99bd68a 100644 --- a/accounts/checkers.go +++ b/accounts/checkers.go @@ -131,7 +131,7 @@ func NewAccountChecker(service Service, } return nil, service.AssociateInvoice( - acct.ID, hash, + ctx, acct.ID, hash, ) }, mid.PassThroughErrorHandler, ), @@ -615,12 +615,12 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, fee := lnrpc.CalculateFeeLimit(limit, sendAmt) sendAmt += fee - err = service.CheckBalance(acct.ID, sendAmt) + err = service.CheckBalance(ctx, acct.ID, sendAmt) if err != nil { return fmt.Errorf("error validating account balance: %w", err) } - err = service.AssociatePayment(acct.ID, pHash, sendAmt) + err = service.AssociatePayment(ctx, acct.ID, pHash, sendAmt) if err != nil { return fmt.Errorf("error associating payment: %w", err) } @@ -661,11 +661,13 @@ func checkSendResponse(ctx context.Context, service Service, if status == lnrpc.Payment_FAILED { service.DeleteValues(reqID) - return nil, service.RemovePayment(hash) + return nil, service.RemovePayment(ctx, hash) } // If there is no immediate failure, make sure we track the payment. - err = service.TrackPayment(acct.ID, hash, lnwire.MilliSatoshi(fullAmt)) + err = service.TrackPayment( + ctx, acct.ID, hash, lnwire.MilliSatoshi(fullAmt), + ) if err != nil { return nil, err } @@ -713,12 +715,12 @@ func checkSendToRoute(ctx context.Context, service Service, paymentHash []byte, } sendAmt += fee - err = service.CheckBalance(acct.ID, sendAmt) + err = service.CheckBalance(ctx, acct.ID, sendAmt) if err != nil { return fmt.Errorf("error validating account balance: %w", err) } - err = service.AssociatePayment(acct.ID, hash, sendAmt) + err = service.AssociatePayment(ctx, acct.ID, hash, sendAmt) if err != nil { return fmt.Errorf("error associating payment with hash %s: %w", hash, err) @@ -749,7 +751,7 @@ func erroredPaymentHandler(service Service) mid.ErrorHandler { "hash: %s and amount: %d", reqVals.PaymentHash, reqVals.PaymentAmount) - err = service.PaymentErrored(acct.ID, reqVals.PaymentHash) + err = service.PaymentErrored(ctx, acct.ID, reqVals.PaymentHash) if err != nil { return nil, err } @@ -812,7 +814,7 @@ func sendToRouteHTLCResponseHandler(service Service) func(ctx context.Context, } err = service.TrackPayment( - acct.ID, reqValues.PaymentHash, + ctx, acct.ID, reqValues.PaymentHash, lnwire.MilliSatoshi(totalAmount), ) if err != nil { diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index 964142380..1398f4d1b 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -71,7 +71,7 @@ func newMockService() *mockService { } } -func (m *mockService) CheckBalance(_ AccountID, +func (m *mockService) CheckBalance(_ context.Context, _ AccountID, wantBalance lnwire.MilliSatoshi) error { if wantBalance > m.acctBalanceMsat { @@ -81,24 +81,28 @@ func (m *mockService) CheckBalance(_ AccountID, return nil } -func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error { +func (m *mockService) AssociateInvoice(_ context.Context, id AccountID, + hash lntypes.Hash) error { + m.trackedInvoices[hash] = id return nil } -func (m *mockService) AssociatePayment(id AccountID, paymentHash lntypes.Hash, - amt lnwire.MilliSatoshi) error { +func (m *mockService) AssociatePayment(_ context.Context, id AccountID, + paymentHash lntypes.Hash, amt lnwire.MilliSatoshi) error { return nil } -func (m *mockService) PaymentErrored(id AccountID, hash lntypes.Hash) error { +func (m *mockService) PaymentErrored(_ context.Context, id AccountID, + hash lntypes.Hash) error { + return nil } -func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash, - amt lnwire.MilliSatoshi) error { +func (m *mockService) TrackPayment(_ context.Context, _ AccountID, + hash lntypes.Hash, amt lnwire.MilliSatoshi) error { m.trackedPayments[hash] = &PaymentEntry{ Status: lnrpc.Payment_UNKNOWN, @@ -108,7 +112,9 @@ func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash, return nil } -func (m *mockService) RemovePayment(hash lntypes.Hash) error { +func (m *mockService) RemovePayment(_ context.Context, + hash lntypes.Hash) error { + delete(m.trackedPayments, hash) return nil @@ -524,7 +530,7 @@ func testSendPayment(t *testing.T, uri string) { require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { - acct, err := service.Account(id) + acct, err := service.Account(ctx, id) require.NoError(t, err) require.Equal(t, expectedBalance, @@ -539,7 +545,7 @@ func testSendPayment(t *testing.T, uri string) { // Create an account and add it to the context. acct, err := service.NewAccount( - 5000, time.Now().Add(time.Hour), "test", + ctx, 5000, time.Now().Add(time.Hour), "test", ) require.NoError(t, err) @@ -720,7 +726,7 @@ func TestSendPaymentV2(t *testing.T) { require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { - acct, err := service.Account(id) + acct, err := service.Account(ctx, id) require.NoError(t, err) require.Equal(t, expectedBalance, @@ -735,7 +741,7 @@ func TestSendPaymentV2(t *testing.T) { // Create an account and add it to the context. acct, err := service.NewAccount( - 5000, time.Now().Add(time.Hour), "test", + ctx, 5000, time.Now().Add(time.Hour), "test", ) require.NoError(t, err) @@ -907,7 +913,7 @@ func TestSendToRouteV2(t *testing.T) { require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { - acct, err := service.Account(id) + acct, err := service.Account(ctx, id) require.NoError(t, err) require.Equal(t, expectedBalance, @@ -922,7 +928,7 @@ func TestSendToRouteV2(t *testing.T) { // Create an account and add it to the context. acct, err := service.NewAccount( - 5000, time.Now().Add(time.Hour), "test", + ctx, 5000, time.Now().Add(time.Hour), "test", ) require.NoError(t, err) diff --git a/accounts/interceptor.go b/accounts/interceptor.go index aa3d759f0..079f4ba07 100644 --- a/accounts/interceptor.go +++ b/accounts/interceptor.go @@ -84,7 +84,7 @@ func (s *InterceptorService) Intercept(ctx context.Context, "macaroon caveat") } - acct, err := s.Account(*acctID) + acct, err := s.Account(ctx, *acctID) if err != nil { return mid.RPCErrString( req, "error getting account %x: %v", acctID[:], err, diff --git a/accounts/interface.go b/accounts/interface.go index 68bb5ad5c..c9c1e78e0 100644 --- a/accounts/interface.go +++ b/accounts/interface.go @@ -1,6 +1,7 @@ package accounts import ( + "context" "encoding/hex" "errors" "fmt" @@ -201,30 +202,34 @@ var ( type Store interface { // NewAccount creates a new OffChainBalanceAccount with the given // balance and a randomly chosen ID. - NewAccount(balance lnwire.MilliSatoshi, expirationDate time.Time, - label string) (*OffChainBalanceAccount, error) + NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, + expirationDate time.Time, label string) ( + *OffChainBalanceAccount, error) // UpdateAccount writes an account to the database, overwriting the // existing one if it exists. - UpdateAccount(account *OffChainBalanceAccount) error + UpdateAccount(ctx context.Context, + account *OffChainBalanceAccount) error // Account retrieves an account from the Store and un-marshals it. If // the account cannot be found, then ErrAccNotFound is returned. - Account(id AccountID) (*OffChainBalanceAccount, error) + Account(ctx context.Context, id AccountID) (*OffChainBalanceAccount, + error) // Accounts retrieves all accounts from the store and un-marshals them. - Accounts() ([]*OffChainBalanceAccount, error) + Accounts(ctx context.Context) ([]*OffChainBalanceAccount, error) // RemoveAccount finds an account by its ID and removes it from the¨ // store. - RemoveAccount(id AccountID) error + RemoveAccount(ctx context.Context, id AccountID) error // LastIndexes returns the last invoice add and settle index or // ErrNoInvoiceIndexKnown if no indexes are known yet. - LastIndexes() (uint64, uint64, error) + LastIndexes(ctx context.Context) (uint64, uint64, error) // StoreLastIndexes stores the last invoice add and settle index. - StoreLastIndexes(addIndex, settleIndex uint64) error + StoreLastIndexes(ctx context.Context, addIndex, + settleIndex uint64) error // Close closes the underlying store. Close() error @@ -234,34 +239,37 @@ type Store interface { type Service interface { // CheckBalance ensures an account is valid and has a balance equal to // or larger than the amount that is required. - CheckBalance(id AccountID, requiredBalance lnwire.MilliSatoshi) error + CheckBalance(ctx context.Context, id AccountID, + requiredBalance lnwire.MilliSatoshi) error // AssociateInvoice associates a generated invoice with the given // account, making it possible for the account to be credited in case // the invoice is paid. - AssociateInvoice(id AccountID, hash lntypes.Hash) error + AssociateInvoice(ctx context.Context, id AccountID, + hash lntypes.Hash) error // TrackPayment adds a new payment to be tracked to the service. If the // payment is eventually settled, its amount needs to be debited from // the given account. - TrackPayment(id AccountID, hash lntypes.Hash, + TrackPayment(ctx context.Context, id AccountID, hash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error // RemovePayment removes a failed payment from the service because it no // longer needs to be tracked. The payment is certain to never succeed, // so we never need to debit the amount from the account. - RemovePayment(hash lntypes.Hash) error + RemovePayment(ctx context.Context, hash lntypes.Hash) error // AssociatePayment associates a payment (hash) with the given account, // ensuring that the payment will be tracked for a user when LiT is // restarted. - AssociatePayment(id AccountID, paymentHash lntypes.Hash, - fullAmt lnwire.MilliSatoshi) error + AssociatePayment(ctx context.Context, id AccountID, + paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error // PaymentErrored removes a pending payment from the accounts // registered payment list. This should only ever be called if we are // sure that the payment request errored out. - PaymentErrored(id AccountID, hash lntypes.Hash) error + PaymentErrored(ctx context.Context, id AccountID, + hash lntypes.Hash) error RequestValuesStore } diff --git a/accounts/rpcserver.go b/accounts/rpcserver.go index 22135f634..5c1314820 100644 --- a/accounts/rpcserver.go +++ b/accounts/rpcserver.go @@ -71,7 +71,7 @@ func (s *RPCServer) CreateAccount(ctx context.Context, // Create the actual account in the macaroon account store. account, err := s.service.NewAccount( - balanceMsat, expirationDate, req.Label, + ctx, balanceMsat, expirationDate, req.Label, ) if err != nil { return nil, fmt.Errorf("unable to create account: %w", err) @@ -109,20 +109,20 @@ func (s *RPCServer) CreateAccount(ctx context.Context, } // UpdateAccount updates an existing account in the account database. -func (s *RPCServer) UpdateAccount(_ context.Context, +func (s *RPCServer) UpdateAccount(ctx context.Context, req *litrpc.UpdateAccountRequest) (*litrpc.Account, error) { log.Infof("[updateaccount] id=%s, label=%v, balance=%d, expiration=%d", req.Id, req.Label, req.AccountBalance, req.ExpirationDate) - accountID, err := s.findAccount(req.Id, req.Label) + accountID, err := s.findAccount(ctx, req.Id, req.Label) if err != nil { return nil, err } // Ask the service to update the account. account, err := s.service.UpdateAccount( - accountID, req.AccountBalance, req.ExpirationDate, + ctx, accountID, req.AccountBalance, req.ExpirationDate, ) if err != nil { return nil, err @@ -133,13 +133,13 @@ func (s *RPCServer) UpdateAccount(_ context.Context, // ListAccounts returns all accounts that are currently stored in the account // database. -func (s *RPCServer) ListAccounts(context.Context, - *litrpc.ListAccountsRequest) (*litrpc.ListAccountsResponse, error) { +func (s *RPCServer) ListAccounts(ctx context.Context, + _ *litrpc.ListAccountsRequest) (*litrpc.ListAccountsResponse, error) { log.Info("[listaccounts]") // Retrieve all accounts from the macaroon account store. - accts, err := s.service.Accounts() + accts, err := s.service.Accounts(ctx) if err != nil { return nil, fmt.Errorf("unable to list accounts: %w", err) } @@ -158,17 +158,17 @@ func (s *RPCServer) ListAccounts(context.Context, } // AccountInfo returns the account with the given ID or label. -func (s *RPCServer) AccountInfo(_ context.Context, +func (s *RPCServer) AccountInfo(ctx context.Context, req *litrpc.AccountInfoRequest) (*litrpc.Account, error) { log.Infof("[accountinfo] id=%v, label=%v", req.Id, req.Label) - accountID, err := s.findAccount(req.Id, req.Label) + accountID, err := s.findAccount(ctx, req.Id, req.Label) if err != nil { return nil, err } - dbAccount, err := s.service.Account(accountID) + dbAccount, err := s.service.Account(ctx, accountID) if err != nil { return nil, fmt.Errorf("error retrieving account: %w", err) } @@ -177,19 +177,19 @@ func (s *RPCServer) AccountInfo(_ context.Context, } // RemoveAccount removes the given account from the account database. -func (s *RPCServer) RemoveAccount(_ context.Context, +func (s *RPCServer) RemoveAccount(ctx context.Context, req *litrpc.RemoveAccountRequest) (*litrpc.RemoveAccountResponse, error) { log.Infof("[removeaccount] id=%v, label=%v", req.Id, req.Label) - accountID, err := s.findAccount(req.Id, req.Label) + accountID, err := s.findAccount(ctx, req.Id, req.Label) if err != nil { return nil, err } // Now remove the account. - err = s.service.RemoveAccount(accountID) + err = s.service.RemoveAccount(ctx, accountID) if err != nil { return nil, fmt.Errorf("error removing account: %w", err) } @@ -198,7 +198,9 @@ func (s *RPCServer) RemoveAccount(_ context.Context, } // findAccount finds an account by its ID or label. -func (s *RPCServer) findAccount(id string, label string) (AccountID, error) { +func (s *RPCServer) findAccount(ctx context.Context, id string, label string) ( + AccountID, error) { + switch { case id != "" && label != "": return AccountID{}, fmt.Errorf("either account ID or label " + @@ -219,7 +221,7 @@ func (s *RPCServer) findAccount(id string, label string) (AccountID, error) { case label != "": // We need to find the account by its label. - accounts, err := s.service.Accounts() + accounts, err := s.service.Accounts(ctx) if err != nil { return AccountID{}, fmt.Errorf("unable to list "+ "accounts: %w", err) diff --git a/accounts/service.go b/accounts/service.go index 5db9ae872..42f4e08b9 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -114,7 +114,7 @@ func (s *InterceptorService) Start(ctx context.Context, // Let's first fill our cache that maps invoices to accounts, which // allows us to credit an account easily once an invoice is settled. We // also track payments that aren't in a final state yet. - existingAccounts, err := s.store.Accounts() + existingAccounts, err := s.store.Accounts(ctx) if err != nil { return s.disableAndErrorf("error querying existing "+ "accounts: %w", err) @@ -132,7 +132,7 @@ func (s *InterceptorService) Start(ctx context.Context, entry := entry if !successState(entry.Status) { err := s.TrackPayment( - acct.ID, hash, entry.FullAmount, + ctx, acct.ID, hash, entry.FullAmount, ) if err != nil { return s.disableAndErrorf("error "+ @@ -145,7 +145,7 @@ func (s *InterceptorService) Start(ctx context.Context, // First ask our DB about the highest indexes we know. If this is the // first startup then the ErrNoInvoiceIndexKnown error is returned, and // we know we need to do a lookup. - s.currentAddIndex, s.currentSettleIndex, err = s.store.LastIndexes() + s.currentAddIndex, s.currentSettleIndex, err = s.store.LastIndexes(ctx) switch err { case nil: // All good, we stored indexes in the DB, use those values. @@ -193,7 +193,8 @@ func (s *InterceptorService) Start(ctx context.Context, return } - if err := s.invoiceUpdate(invoice); err != nil { + err := s.invoiceUpdate(ctx, invoice) + if err != nil { log.Errorf("Error processing invoice "+ "update: %v", err) @@ -289,19 +290,21 @@ func (s *InterceptorService) disableAndErrorfUnsafe(format string, // NewAccount creates a new OffChainBalanceAccount with the given balance and a // randomly chosen ID. -func (s *InterceptorService) NewAccount(balance lnwire.MilliSatoshi, +func (s *InterceptorService) NewAccount(ctx context.Context, + balance lnwire.MilliSatoshi, expirationDate time.Time, label string) (*OffChainBalanceAccount, error) { s.Lock() defer s.Unlock() - return s.store.NewAccount(balance, expirationDate, label) + return s.store.NewAccount(ctx, balance, expirationDate, label) } // UpdateAccount writes an account to the database, overwriting the existing one // if it exists. -func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, +func (s *InterceptorService) UpdateAccount(ctx context.Context, + accountID AccountID, accountBalance, expirationDate int64) (*OffChainBalanceAccount, error) { s.Lock() @@ -315,7 +318,7 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, return nil, ErrAccountServiceDisabled } - account, err := s.store.Account(accountID) + account, err := s.store.Account(ctx, accountID) if err != nil { return nil, fmt.Errorf("error fetching account: %w", err) } @@ -339,7 +342,7 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, } // Create the actual account in the macaroon account store. - err = s.store.UpdateAccount(account) + err = s.store.UpdateAccount(ctx, account) if err != nil { return nil, fmt.Errorf("unable to update account: %w", err) } @@ -349,25 +352,29 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, // Account retrieves an account from the bolt DB and un-marshals it. If the // account cannot be found, then ErrAccNotFound is returned. -func (s *InterceptorService) Account(id AccountID) (*OffChainBalanceAccount, - error) { +func (s *InterceptorService) Account(ctx context.Context, + id AccountID) (*OffChainBalanceAccount, error) { s.RLock() defer s.RUnlock() - return s.store.Account(id) + return s.store.Account(ctx, id) } // Accounts retrieves all accounts from the bolt DB and un-marshals them. -func (s *InterceptorService) Accounts() ([]*OffChainBalanceAccount, error) { +func (s *InterceptorService) Accounts(ctx context.Context) ( + []*OffChainBalanceAccount, error) { + s.RLock() defer s.RUnlock() - return s.store.Accounts() + return s.store.Accounts(ctx) } // RemoveAccount finds an account by its ID and removes it from the DB. -func (s *InterceptorService) RemoveAccount(id AccountID) error { +func (s *InterceptorService) RemoveAccount(ctx context.Context, + id AccountID) error { + s.Lock() defer s.Unlock() @@ -378,18 +385,18 @@ func (s *InterceptorService) RemoveAccount(id AccountID) error { } // Let's remove the payment (which also cancels the tracking). - err := s.removePayment(hash, lnrpc.Payment_FAILED) + err := s.removePayment(ctx, hash, lnrpc.Payment_FAILED) if err != nil { return err } } - return s.store.RemoveAccount(id) + return s.store.RemoveAccount(ctx, id) } // CheckBalance ensures an account is valid and has a balance equal to or larger // than the amount that is required. -func (s *InterceptorService) CheckBalance(id AccountID, +func (s *InterceptorService) CheckBalance(ctx context.Context, id AccountID, requiredBalance lnwire.MilliSatoshi) error { s.RLock() @@ -397,7 +404,7 @@ func (s *InterceptorService) CheckBalance(id AccountID, // Check that the account exists, it hasn't expired and has sufficient // balance. - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } @@ -431,26 +438,27 @@ func calcAvailableAccountBalance(account *OffChainBalanceAccount) int64 { // AssociateInvoice associates a generated invoice with the given account, // making it possible for the account to be credited in case the invoice is // paid. -func (s *InterceptorService) AssociateInvoice(id AccountID, +func (s *InterceptorService) AssociateInvoice(ctx context.Context, id AccountID, hash lntypes.Hash) error { s.Lock() defer s.Unlock() - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } account.Invoices[hash] = struct{}{} s.invoiceToAccount[hash] = id - return s.store.UpdateAccount(account) + + return s.store.UpdateAccount(ctx, account) } // PaymentErrored removes a pending payment from the account's registered // payment list. This should only ever be called if we are sure that the payment // request errored out. -func (s *InterceptorService) PaymentErrored(id AccountID, +func (s *InterceptorService) PaymentErrored(ctx context.Context, id AccountID, hash lntypes.Hash) error { s.Lock() @@ -464,7 +472,7 @@ func (s *InterceptorService) PaymentErrored(id AccountID, "has already started") } - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } @@ -479,7 +487,7 @@ func (s *InterceptorService) PaymentErrored(id AccountID, // Delete the payment and update the persisted account. delete(account.Payments, hash) - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { return fmt.Errorf("error updating account: %w", err) } @@ -489,13 +497,13 @@ func (s *InterceptorService) PaymentErrored(id AccountID, // AssociatePayment associates a payment (hash) with the given account, // ensuring that the payment will be tracked for a user when LiT is // restarted. -func (s *InterceptorService) AssociatePayment(id AccountID, +func (s *InterceptorService) AssociatePayment(ctx context.Context, id AccountID, paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error { s.Lock() defer s.Unlock() - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } @@ -528,7 +536,7 @@ func (s *InterceptorService) AssociatePayment(id AccountID, FullAmount: fullAmt, } - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { return fmt.Errorf("error updating account: %w", err) } @@ -543,7 +551,9 @@ func (s *InterceptorService) AssociatePayment(id AccountID, // the same lock. Else we risk that other threads will try to update invoices // while the service should be disabled, which could lead to us missing invoice // updates on next startup. -func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { +func (s *InterceptorService) invoiceUpdate(ctx context.Context, + invoice *lndclient.Invoice) error { + s.Lock() defer s.Unlock() @@ -572,7 +582,7 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { if needUpdate { err := s.store.StoreLastIndexes( - s.currentAddIndex, s.currentSettleIndex, + ctx, s.currentAddIndex, s.currentSettleIndex, ) if err != nil { return s.disableAndErrorfUnsafe( @@ -594,7 +604,7 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { return nil } - account, err := s.store.Account(acctID) + account, err := s.store.Account(ctx, acctID) if err != nil { return s.disableAndErrorfUnsafe( "error fetching account: %w", err, @@ -605,7 +615,7 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { // it that was just paid. Credit the amount to the account and update it // in the DB. account.CurrentBalance += int64(invoice.AmountPaid) - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { return s.disableAndErrorfUnsafe( "error updating account: %w", err, ) @@ -620,8 +630,8 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { // TrackPayment adds a new payment to be tracked to the service. If the payment // is eventually settled, its amount needs to be debited from the given account. -func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, - fullAmt lnwire.MilliSatoshi) error { +func (s *InterceptorService) TrackPayment(ctx context.Context, id AccountID, + hash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error { s.Lock() defer s.Unlock() @@ -634,7 +644,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // Similarly, if we've already processed the payment in the past, there // is a reference in the account with the given state. - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return fmt.Errorf("error fetching account: %w", err) } @@ -658,7 +668,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, FullAmount: fullAmt, } - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { if !ok { // In the rare case that the payment isn't associated // with an account yet, and we fail to update the @@ -718,7 +728,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, select { case paymentUpdate := <-statusChan: terminalState, err := s.paymentUpdate( - hash, paymentUpdate, + s.mainCtx, hash, paymentUpdate, ) if err != nil { s.mainErrCallback(err) @@ -746,7 +756,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // seen as in-flight balance when // calculating the account's available // balance. - err := s.RemovePayment(hash) + err := s.RemovePayment(ctx, hash) if err != nil { // We don't disable the service // here, as the worst that can @@ -789,8 +799,8 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // NOTE: Any code that errors in this function MUST call disableAndErrorfUnsafe // while the store lock is held to ensure that the service is disabled under // the same lock. -func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, - status lndclient.PaymentStatus) (bool, error) { +func (s *InterceptorService) paymentUpdate(ctx context.Context, + hash lntypes.Hash, status lndclient.PaymentStatus) (bool, error) { // Are we still in-flight? Then we don't have to do anything just yet. // The unknown state should never happen in practice but if it ever did @@ -824,7 +834,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // A failed payment can just be removed, no further action needed. if status.State == lnrpc.Payment_FAILED { - err := s.removePayment(hash, status.State) + err := s.removePayment(ctx, hash, status.State) if err != nil { err = s.disableAndErrorfUnsafe("error removing "+ "payment: %w", err) @@ -835,7 +845,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // The payment went through! We now need to debit the full amount from // the account. - account, err := s.store.Account(pendingPayment.accountID) + account, err := s.store.Account(ctx, pendingPayment.accountID) if err != nil { err = s.disableAndErrorfUnsafe("error fetching account: %w", err) @@ -851,7 +861,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, Status: lnrpc.Payment_SUCCEEDED, FullAmount: fullAmount, } - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { err = s.disableAndErrorfUnsafe("error updating account: %w", err) @@ -860,7 +870,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // We've now fully processed the payment and don't need to keep it // mapped or tracked anymore. - err = s.removePayment(hash, lnrpc.Payment_SUCCEEDED) + err = s.removePayment(ctx, hash, lnrpc.Payment_SUCCEEDED) if err != nil { err = s.disableAndErrorfUnsafe("error removing payment: %w", err) @@ -872,19 +882,21 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // RemovePayment removes a failed payment from the service because it no longer // needs to be tracked. The payment is certain to never succeed, so we never // need to debit the amount from the account. -func (s *InterceptorService) RemovePayment(hash lntypes.Hash) error { +func (s *InterceptorService) RemovePayment(ctx context.Context, + hash lntypes.Hash) error { + s.Lock() defer s.Unlock() - return s.removePayment(hash, lnrpc.Payment_FAILED) + return s.removePayment(ctx, hash, lnrpc.Payment_FAILED) } // removePayment stops tracking a payment and updates the status in the account // to the given status. // // NOTE: The store lock MUST be held when calling this method. -func (s *InterceptorService) removePayment(hash lntypes.Hash, - status lnrpc.Payment_PaymentStatus) error { +func (s *InterceptorService) removePayment(ctx context.Context, + hash lntypes.Hash, status lnrpc.Payment_PaymentStatus) error { // It could be that we haven't actually started tracking the payment // yet, so if we can't find it, we just do nothing. @@ -893,7 +905,7 @@ func (s *InterceptorService) removePayment(hash lntypes.Hash, return nil } - account, err := s.store.Account(pendingPayment.accountID) + account, err := s.store.Account(ctx, pendingPayment.accountID) if err != nil { return err } @@ -909,7 +921,7 @@ func (s *InterceptorService) removePayment(hash lntypes.Hash, // If we did, let's set the status correctly in the DB now. account.Payments[hash].Status = status - return s.store.UpdateAccount(account) + return s.store.UpdateAccount(ctx, account) } // successState returns true if a payment was completed successfully. diff --git a/accounts/service_test.go b/accounts/service_test.go index b38b119a4..946a8f292 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -197,6 +197,7 @@ func (r *mockRouter) TrackPayment(_ context.Context, // invoices of account related calls correctly. func TestAccountService(t *testing.T) { t.Parallel() + ctx := context.Background() testCases := []struct { name string @@ -233,7 +234,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -279,7 +280,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -314,7 +315,7 @@ func TestAccountService(t *testing.T) { Payments: make(AccountPayments), } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) s.mainErrCallback(testErr) @@ -331,7 +332,7 @@ func TestAccountService(t *testing.T) { s *InterceptorService) { acct, err := s.store.NewAccount( - 1234, testExpiration, "", + ctx, 1234, testExpiration, "", ) require.NoError(t, err) @@ -341,7 +342,7 @@ func TestAccountService(t *testing.T) { FullAmount: 1234, } - err = s.store.UpdateAccount(acct) + err = s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -373,7 +374,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) r.trackPaymentErr = testErr @@ -410,7 +411,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -463,7 +464,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -516,7 +517,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -540,7 +541,7 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) return acct.CurrentBalance == 3000 @@ -556,7 +557,7 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) if len(acct.Payments) != 3 { @@ -582,10 +583,10 @@ func TestAccountService(t *testing.T) { // First check that the account has an available balance // of 1000. That means that the payment with testHash3 // and amount 2000 is still considered to be in-flight. - err := s.CheckBalance(testID, 1000) + err := s.CheckBalance(ctx, testID, 1000) require.NoError(t, err) - err = s.CheckBalance(testID, 1001) + err = s.CheckBalance(ctx, testID, 1001) require.ErrorIs(t, err, ErrAccBalanceInsufficient) // Now signal that the payment was non-initiated. @@ -595,8 +596,8 @@ func TestAccountService(t *testing.T) { // goroutine, and therefore free up the 2000 in-flight // balance. assertEventually(t, func() bool { - bal3000Err := s.CheckBalance(testID, 3000) - bal3001Err := s.CheckBalance(testID, 3001) + bal3000Err := s.CheckBalance(ctx, testID, 3000) + bal3001Err := s.CheckBalance(ctx, testID, 3001) require.ErrorIs( t, bal3001Err, ErrAccBalanceInsufficient, @@ -606,7 +607,7 @@ func TestAccountService(t *testing.T) { // Ensure that the payment is also set to the // failed status. - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) p, ok := acct.Payments[testHash3] @@ -626,7 +627,7 @@ func TestAccountService(t *testing.T) { setup: func(t *testing.T, lnd *mockLnd, r *mockRouter, s *InterceptorService) { - err := s.store.StoreLastIndexes(987_654, 555_555) + err := s.store.StoreLastIndexes(ctx, 987_654, 555_555) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -645,7 +646,9 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - addIdx, settleIdx, err := s.store.LastIndexes() + addIdx, settleIdx, err := s.store.LastIndexes( + ctx, + ) require.NoError(t, err) if addIdx != 987_654 { @@ -662,7 +665,9 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - addIdx, settleIdx, err := s.store.LastIndexes() + addIdx, settleIdx, err := s.store.LastIndexes( + ctx, + ) require.NoError(t, err) if addIdx != 1_000_000 { @@ -688,7 +693,7 @@ func TestAccountService(t *testing.T) { Payments: make(AccountPayments), } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -705,7 +710,7 @@ func TestAccountService(t *testing.T) { // Make sure the amount paid is eventually credited. assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) return acct.CurrentBalance == 1000 @@ -723,7 +728,7 @@ func TestAccountService(t *testing.T) { // Ensure that the balance now adds up to the sum of // both invoices. assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) return acct.CurrentBalance == (1000 + 777) @@ -757,7 +762,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) // The second account has one in-flight payment of 4k @@ -777,7 +782,7 @@ func TestAccountService(t *testing.T) { }, } - err = s.store.UpdateAccount(acct2) + err = s.store.UpdateAccount(ctx, acct2) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -787,11 +792,11 @@ func TestAccountService(t *testing.T) { // with an amount smaller or equal to 2k msats. This // also asserts that the second accounts in-flight // payment doesn't affect the first account. - err := s.CheckBalance(testID, 2000) + err := s.CheckBalance(ctx, testID, 2000) require.NoError(t, err) // But exactly one sat over it should fail. - err = s.CheckBalance(testID, 2001) + err = s.CheckBalance(ctx, testID, 2001) require.ErrorIs(t, err, ErrAccBalanceInsufficient) // Remove one of the payments (to simulate it failed) @@ -802,17 +807,17 @@ func TestAccountService(t *testing.T) { // We should now have up to 4k msats available. assertEventually(t, func() bool { - err = s.CheckBalance(testID, 4000) + err = s.CheckBalance(ctx, testID, 4000) return err == nil }) // The second account should be able to initiate a // payment of 1k msats. - err = s.CheckBalance(testID2, 1000) + err = s.CheckBalance(ctx, testID2, 1000) require.NoError(t, err) // But exactly one sat over it should fail. - err = s.CheckBalance(testID2, 1001) + err = s.CheckBalance(ctx, testID2, 1001) require.ErrorIs(t, err, ErrAccBalanceInsufficient) }, }} @@ -839,8 +844,7 @@ func TestAccountService(t *testing.T) { // Any errors during startup expected? err = service.Start( - context.Background(), lndMock, routerMock, - chainParams, + ctx, lndMock, routerMock, chainParams, ) if tc.startupErr != "" { require.ErrorContains(tt, err, tc.startupErr) diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go index ebaf937be..e4e22cb7d 100644 --- a/accounts/store_kvdb.go +++ b/accounts/store_kvdb.go @@ -2,6 +2,7 @@ package accounts import ( "bytes" + "context" "crypto/rand" "encoding/binary" "encoding/hex" @@ -103,7 +104,7 @@ func (s *BoltStore) Close() error { // NewAccount creates a new OffChainBalanceAccount with the given balance and a // randomly chosen ID. -func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, +func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, expirationDate time.Time, label string) (*OffChainBalanceAccount, error) { @@ -120,7 +121,7 @@ func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, label) } - accounts, err := s.Accounts() + accounts, err := s.Accounts(ctx) if err != nil { return nil, fmt.Errorf("error checking label "+ "uniqueness: %w", err) @@ -174,7 +175,9 @@ func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, // UpdateAccount writes an account to the database, overwriting the existing one // if it exists. -func (s *BoltStore) UpdateAccount(account *OffChainBalanceAccount) error { +func (s *BoltStore) UpdateAccount(_ context.Context, + account *OffChainBalanceAccount) error { + return s.db.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { @@ -225,7 +228,9 @@ func uniqueRandomAccountID(accountBucket kvdb.RBucket) (AccountID, error) { // Account retrieves an account from the bolt DB and un-marshals it. If the // account cannot be found, then ErrAccNotFound is returned. -func (s *BoltStore) Account(id AccountID) (*OffChainBalanceAccount, error) { +func (s *BoltStore) Account(_ context.Context, id AccountID) ( + *OffChainBalanceAccount, error) { + // Try looking up and reading the account by its ID from the local // bolt DB. var accountBinary []byte @@ -259,7 +264,9 @@ func (s *BoltStore) Account(id AccountID) (*OffChainBalanceAccount, error) { } // Accounts retrieves all accounts from the bolt DB and un-marshals them. -func (s *BoltStore) Accounts() ([]*OffChainBalanceAccount, error) { +func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, + error) { + var accounts []*OffChainBalanceAccount err := s.db.View(func(tx kvdb.RTx) error { // This function will be called in the ForEach and receive @@ -302,7 +309,7 @@ func (s *BoltStore) Accounts() ([]*OffChainBalanceAccount, error) { } // RemoveAccount finds an account by its ID and removes it from the DB. -func (s *BoltStore) RemoveAccount(id AccountID) error { +func (s *BoltStore) RemoveAccount(_ context.Context, id AccountID) error { return s.db.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { @@ -320,7 +327,7 @@ func (s *BoltStore) RemoveAccount(id AccountID) error { // LastIndexes returns the last invoice add and settle index or // ErrNoInvoiceIndexKnown if no indexes are known yet. -func (s *BoltStore) LastIndexes() (uint64, uint64, error) { +func (s *BoltStore) LastIndexes(_ context.Context) (uint64, uint64, error) { var ( addValue, settleValue []byte ) @@ -352,7 +359,9 @@ func (s *BoltStore) LastIndexes() (uint64, uint64, error) { } // StoreLastIndexes stores the last invoice add and settle index. -func (s *BoltStore) StoreLastIndexes(addIndex, settleIndex uint64) error { +func (s *BoltStore) StoreLastIndexes(_ context.Context, addIndex, + settleIndex uint64) error { + addValue := make([]byte, 8) settleValue := make([]byte, 8) byteOrder.PutUint64(addValue, addIndex) diff --git a/accounts/store_test.go b/accounts/store_test.go index 2f661febc..7d00d0f3c 100644 --- a/accounts/store_test.go +++ b/accounts/store_test.go @@ -1,6 +1,7 @@ package accounts import ( + "context" "testing" "time" @@ -12,26 +13,27 @@ import ( // TestAccountStore tests that accounts can be stored and retrieved correctly. func TestAccountStore(t *testing.T) { t.Parallel() + ctx := context.Background() store, err := NewBoltStore(t.TempDir(), DBFilename) require.NoError(t, err) // Create an account that does not expire. - acct1, err := store.NewAccount(0, time.Time{}, "foo") + acct1, err := store.NewAccount(ctx, 0, time.Time{}, "foo") require.NoError(t, err) require.False(t, acct1.HasExpired()) - dbAccount, err := store.Account(acct1.ID) + dbAccount, err := store.Account(ctx, acct1.ID) require.NoError(t, err) assertEqualAccounts(t, acct1, dbAccount) // Make sure we cannot create a second account with the same label. - _, err = store.NewAccount(123, time.Time{}, "foo") + _, err = store.NewAccount(ctx, 123, time.Time{}, "foo") require.ErrorContains(t, err, "account with the label 'foo' already") // Make sure we cannot set a label that looks like an account ID. - _, err = store.NewAccount(123, time.Time{}, "0011223344556677") + _, err = store.NewAccount(ctx, 123, time.Time{}, "0011223344556677") require.ErrorContains(t, err, "is not allowed as it can be mistaken") // Update all values of the account that we can modify. @@ -47,10 +49,10 @@ func TestAccountStore(t *testing.T) { } acct1.Invoices[lntypes.Hash{12, 34, 56, 78}] = struct{}{} acct1.Invoices[lntypes.Hash{34, 56, 78, 90}] = struct{}{} - err = store.UpdateAccount(acct1) + err = store.UpdateAccount(ctx, acct1) require.NoError(t, err) - dbAccount, err = store.Account(acct1.ID) + dbAccount, err = store.Account(ctx, acct1.ID) require.NoError(t, err) assertEqualAccounts(t, acct1, dbAccount) @@ -62,18 +64,18 @@ func TestAccountStore(t *testing.T) { require.True(t, acct1.HasExpired()) // Test listing and deleting accounts. - accounts, err := store.Accounts() + accounts, err := store.Accounts(ctx) require.NoError(t, err) require.Len(t, accounts, 1) - err = store.RemoveAccount(acct1.ID) + err = store.RemoveAccount(ctx, acct1.ID) require.NoError(t, err) - accounts, err = store.Accounts() + accounts, err = store.Accounts(ctx) require.NoError(t, err) require.Len(t, accounts, 0) - _, err = store.Account(acct1.ID) + _, err = store.Account(ctx, acct1.ID) require.ErrorIs(t, err, ErrAccNotFound) } @@ -108,16 +110,17 @@ func assertEqualAccounts(t *testing.T, expected, // stored and retrieved correctly. func TestLastInvoiceIndexes(t *testing.T) { t.Parallel() + ctx := context.Background() store, err := NewBoltStore(t.TempDir(), DBFilename) require.NoError(t, err) - _, _, err = store.LastIndexes() + _, _, err = store.LastIndexes(ctx) require.ErrorIs(t, err, ErrNoInvoiceIndexKnown) - require.NoError(t, store.StoreLastIndexes(7, 99)) + require.NoError(t, store.StoreLastIndexes(ctx, 7, 99)) - add, settle, err := store.LastIndexes() + add, settle, err := store.LastIndexes(ctx) require.NoError(t, err) require.EqualValues(t, 7, add) require.EqualValues(t, 99, settle) From 82eeadd819ad997e7bfcb0cf46a7af1fca2e42a3 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 27 Dec 2024 19:58:28 +0200 Subject: [PATCH 3/8] accounts: pass Store impl to NewService We want to be able to pass different DB implementations to NewService. In preparation for this, we make it implementation agnostic by letting it take a `Store` instead of constructing one itself. This this change, we also let LiT handle the closing of the accounts Store instead of the accounts service --- accounts/checkers_test.go | 12 +++++++++--- accounts/service.go | 13 ++++--------- accounts/service_test.go | 6 ++++-- terminal.go | 18 +++++++++++++++++- 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index 1398f4d1b..d5c772461 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -523,7 +523,9 @@ func testSendPayment(t *testing.T, uri string) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store, err := NewBoltStore(t.TempDir(), DBFilename) + require.NoError(t, err) + service, err := NewService(store, errFunc) require.NoError(t, err) err = service.Start(ctx, lndMock, routerMock, chainParams) @@ -719,7 +721,9 @@ func TestSendPaymentV2(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store, err := NewBoltStore(t.TempDir(), DBFilename) + require.NoError(t, err) + service, err := NewService(store, errFunc) require.NoError(t, err) err = service.Start(ctx, lndMock, routerMock, chainParams) @@ -906,7 +910,9 @@ func TestSendToRouteV2(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store, err := NewBoltStore(t.TempDir(), DBFilename) + require.NoError(t, err) + service, err := NewService(store, errFunc) require.NoError(t, err) err = service.Start(ctx, lndMock, routerMock, chainParams) diff --git a/accounts/service.go b/accounts/service.go index 42f4e08b9..820dad23e 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -78,16 +78,11 @@ type InterceptorService struct { // NewService returns a service backed by the macaroon Bolt DB stored in the // passed-in directory. -func NewService(dir string, - errCallback func(error)) (*InterceptorService, error) { - - accountStore, err := NewBoltStore(dir, DBFilename) - if err != nil { - return nil, err - } +func NewService(store Store, errCallback func(error)) (*InterceptorService, + error) { return &InterceptorService{ - store: accountStore, + store: store, invoiceToAccount: make(map[lntypes.Hash]AccountID), pendingPayments: make(map[lntypes.Hash]*trackedPayment), requestValuesStore: newRequestValuesStore(), @@ -242,7 +237,7 @@ func (s *InterceptorService) Stop() error { close(s.quit) s.wg.Wait() - return s.store.Close() + return nil } // IsRunning checks if the account service is running, and returns a boolean diff --git a/accounts/service_test.go b/accounts/service_test.go index 946a8f292..8834a3a06 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -243,7 +243,7 @@ func TestAccountService(t *testing.T) { // Start by closing the store. This should cause an // error once we make an invoice update, as the service // will fail when persisting the invoice update. - s.store.Close() + require.NoError(t, s.store.Close()) // Ensure that the service was started successfully and // still running though, despite the closing of the @@ -833,7 +833,9 @@ func TestAccountService(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store, err := NewBoltStore(t.TempDir(), DBFilename) + require.NoError(tt, err) + service, err := NewService(store, errFunc) require.NoError(t, err) // Is a setup call required to initialize initial diff --git a/terminal.go b/terminal.go index c92827ad4..84a7f42a4 100644 --- a/terminal.go +++ b/terminal.go @@ -214,6 +214,7 @@ type LightningTerminal struct { middleware *mid.Manager middlewareStarted bool + accountsStore *accounts.BoltStore accountService *accounts.InterceptorService accountServiceStarted bool @@ -412,8 +413,15 @@ func (g *LightningTerminal) start(ctx context.Context) error { ) } + g.accountsStore, err = accounts.NewBoltStore( + filepath.Dir(g.cfg.MacaroonPath), accounts.DBFilename, + ) + if err != nil { + return fmt.Errorf("error creating accounts store: %w", err) + } + g.accountService, err = accounts.NewService( - filepath.Dir(g.cfg.MacaroonPath), accountServiceErrCallback, + g.accountsStore, accountServiceErrCallback, ) if err != nil { return fmt.Errorf("error creating account service: %v", err) @@ -1421,6 +1429,14 @@ func (g *LightningTerminal) shutdownSubServers() error { } } + if g.accountsStore != nil { + err = g.accountsStore.Close() + if err != nil { + log.Errorf("Error closing accounts store: %v", err) + returnErr = err + } + } + if g.middlewareStarted { g.middleware.Stop() } From 51d864a24c89442427ca7d2e483ae41a7b8cebbb Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 4 Jan 2025 13:29:42 +0200 Subject: [PATCH 4/8] accounts: DB constructors for tests This commit adds two new test helpers, NewTestDB and NewTestDBFromPath in a file that is only built when the test_db_postgres and test_db_sqlite build flags are not set. When we add sql backends, we will add helpers with the same names for each new backend. We will then use the appropriate build flags to run our unit tests against all backends. --- accounts/checkers_test.go | 9 +++------ accounts/service_test.go | 5 ++--- accounts/store_test.go | 8 +++----- accounts/test_kvdb.go | 30 ++++++++++++++++++++++++++++++ 4 files changed, 38 insertions(+), 14 deletions(-) create mode 100644 accounts/test_kvdb.go diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index d5c772461..8c8c6c763 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -523,8 +523,7 @@ func testSendPayment(t *testing.T, uri string) { errFunc := func(err error) { lndMock.mainErrChan <- err } - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(t, err) + store := NewTestDB(t) service, err := NewService(store, errFunc) require.NoError(t, err) @@ -721,8 +720,7 @@ func TestSendPaymentV2(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(t, err) + store := NewTestDB(t) service, err := NewService(store, errFunc) require.NoError(t, err) @@ -910,8 +908,7 @@ func TestSendToRouteV2(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(t, err) + store := NewTestDB(t) service, err := NewService(store, errFunc) require.NoError(t, err) diff --git a/accounts/service_test.go b/accounts/service_test.go index 8834a3a06..9f7927c3a 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -264,7 +264,7 @@ func TestAccountService(t *testing.T) { isRunning := s.IsRunning() return isRunning == false }) - lnd.assertMainErrContains(t, "database not open") + lnd.assertMainErrContains(t, ErrDBClosed.Error()) }, }, { name: "err in invoice err channel", @@ -833,8 +833,7 @@ func TestAccountService(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(tt, err) + store := NewTestDB(t) service, err := NewService(store, errFunc) require.NoError(t, err) diff --git a/accounts/store_test.go b/accounts/store_test.go index 7d00d0f3c..b0196517c 100644 --- a/accounts/store_test.go +++ b/accounts/store_test.go @@ -15,8 +15,7 @@ func TestAccountStore(t *testing.T) { t.Parallel() ctx := context.Background() - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(t, err) + store := NewTestDB(t) // Create an account that does not expire. acct1, err := store.NewAccount(ctx, 0, time.Time{}, "foo") @@ -112,10 +111,9 @@ func TestLastInvoiceIndexes(t *testing.T) { t.Parallel() ctx := context.Background() - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(t, err) + store := NewTestDB(t) - _, _, err = store.LastIndexes(ctx) + _, _, err := store.LastIndexes(ctx) require.ErrorIs(t, err, ErrNoInvoiceIndexKnown) require.NoError(t, store.StoreLastIndexes(ctx, 7, 99)) diff --git a/accounts/test_kvdb.go b/accounts/test_kvdb.go new file mode 100644 index 000000000..b050d149c --- /dev/null +++ b/accounts/test_kvdb.go @@ -0,0 +1,30 @@ +package accounts + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// ErrDBClosed is an error that is returned when a database operation is +// performed on a closed database. +var ErrDBClosed = errors.New("database not open") + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T) *BoltStore { + return NewTestDBFromPath(t, t.TempDir()) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string) *BoltStore { + store, err := NewBoltStore(dbPath, DBFilename) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, store.db.Close()) + }) + + return store +} From 2012fbaa411a9b5b7e691f59e8014f99a7e77ee9 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 4 Jan 2025 13:43:55 +0200 Subject: [PATCH 5/8] accounts: update comments for the BoltStore Store impl --- accounts/store_kvdb.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go index e4e22cb7d..c45403087 100644 --- a/accounts/store_kvdb.go +++ b/accounts/store_kvdb.go @@ -98,12 +98,16 @@ func NewBoltStore(dir, fileName string) (*BoltStore, error) { } // Close closes the underlying bolt DB. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) Close() error { return s.db.Close() } // NewAccount creates a new OffChainBalanceAccount with the given balance and a // randomly chosen ID. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, expirationDate time.Time, label string) (*OffChainBalanceAccount, error) { @@ -175,6 +179,8 @@ func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, // UpdateAccount writes an account to the database, overwriting the existing one // if it exists. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) UpdateAccount(_ context.Context, account *OffChainBalanceAccount) error { @@ -228,6 +234,8 @@ func uniqueRandomAccountID(accountBucket kvdb.RBucket) (AccountID, error) { // Account retrieves an account from the bolt DB and un-marshals it. If the // account cannot be found, then ErrAccNotFound is returned. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) Account(_ context.Context, id AccountID) ( *OffChainBalanceAccount, error) { @@ -264,6 +272,8 @@ func (s *BoltStore) Account(_ context.Context, id AccountID) ( } // Accounts retrieves all accounts from the bolt DB and un-marshals them. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, error) { @@ -309,6 +319,8 @@ func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, } // RemoveAccount finds an account by its ID and removes it from the DB. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) RemoveAccount(_ context.Context, id AccountID) error { return s.db.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) @@ -327,6 +339,8 @@ func (s *BoltStore) RemoveAccount(_ context.Context, id AccountID) error { // LastIndexes returns the last invoice add and settle index or // ErrNoInvoiceIndexKnown if no indexes are known yet. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) LastIndexes(_ context.Context) (uint64, uint64, error) { var ( addValue, settleValue []byte @@ -359,6 +373,8 @@ func (s *BoltStore) LastIndexes(_ context.Context) (uint64, uint64, error) { } // StoreLastIndexes stores the last invoice add and settle index. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) StoreLastIndexes(_ context.Context, addIndex, settleIndex uint64) error { From 9f08daf94892bcae97a025a2a33967479260e7b0 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 4 Jan 2025 13:47:04 +0200 Subject: [PATCH 6/8] accounts: let storeAccount set the LastUpdate timestamp In later commits, we will use this `storeAccount` helper quite often. Instead of needing to remember to update the timestamp outside the call, it make sense to instead update the timestamp within the function. Yes this does mean that sometimes we make no overall changes but do update the timestamp but this is a pretty standard pattern that a "last updated" timestamp is updated at any point that we re-write a record (even if it does not have a net change). --- accounts/store_kvdb.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go index c45403087..49992bd3f 100644 --- a/accounts/store_kvdb.go +++ b/accounts/store_kvdb.go @@ -145,7 +145,6 @@ func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, InitialBalance: balance, CurrentBalance: int64(balance), ExpirationDate: expirationDate, - LastUpdate: time.Now(), Invoices: make(AccountInvoices), Payments: make(AccountPayments), Label: label, @@ -190,7 +189,6 @@ func (s *BoltStore) UpdateAccount(_ context.Context, return ErrAccountBucketNotFound } - account.LastUpdate = time.Now() return storeAccount(bucket, account) }, func() {}) } @@ -200,6 +198,8 @@ func (s *BoltStore) UpdateAccount(_ context.Context, func storeAccount(accountBucket kvdb.RwBucket, account *OffChainBalanceAccount) error { + account.LastUpdate = time.Now() + accountBinary, err := serializeAccount(account) if err != nil { return err From 7fa370ebc76e696b0e572d48e7a6f322e214a78d Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 5 Jan 2025 08:37:03 +0200 Subject: [PATCH 7/8] accounts: return consistent error for duplicate label In preparation for when we have a SQL DB implementation, we want our unit tests to run smoothly against all DB backends and have the same results. To achieve this, we need to turn some errors into global error variables that can be matched against instead. In this commit, we do this for the unique constraint violation of the account label. --- accounts/errors.go | 11 +++++++++++ accounts/store_kvdb.go | 3 ++- accounts/store_test.go | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 accounts/errors.go diff --git a/accounts/errors.go b/accounts/errors.go new file mode 100644 index 000000000..8b3a59afb --- /dev/null +++ b/accounts/errors.go @@ -0,0 +1,11 @@ +package accounts + +import "errors" + +var ( + // ErrLabelAlreadyExists is returned by the CreateAccount method if the + // account label is already used by an existing account. + ErrLabelAlreadyExists = errors.New( + "account label uniqueness constraint violation", + ) +) diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go index 49992bd3f..84f22f891 100644 --- a/accounts/store_kvdb.go +++ b/accounts/store_kvdb.go @@ -133,7 +133,8 @@ func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, for _, account := range accounts { if account.Label == label { return nil, fmt.Errorf("an account with the "+ - "label '%s' already exists", label) + "label '%s' already exists: %w", label, + ErrLabelAlreadyExists) } } } diff --git a/accounts/store_test.go b/accounts/store_test.go index b0196517c..b7167c3cf 100644 --- a/accounts/store_test.go +++ b/accounts/store_test.go @@ -29,7 +29,7 @@ func TestAccountStore(t *testing.T) { // Make sure we cannot create a second account with the same label. _, err = store.NewAccount(ctx, 123, time.Time{}, "foo") - require.ErrorContains(t, err, "account with the label 'foo' already") + require.ErrorIs(t, err, ErrLabelAlreadyExists) // Make sure we cannot set a label that looks like an account ID. _, err = store.NewAccount(ctx, 123, time.Time{}, "0011223344556677") From 6f131a14a4a1c060858873f38c87980548474a3b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 5 Jan 2025 08:37:34 +0200 Subject: [PATCH 8/8] accounts+refactor: improve test readability --- accounts/service_test.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/accounts/service_test.go b/accounts/service_test.go index 9f7927c3a..2a28b9174 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -261,8 +261,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) lnd.assertMainErrContains(t, ErrDBClosed.Error()) }, @@ -294,8 +293,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) lnd.assertMainErrContains(t, testErr.Error()) @@ -440,8 +438,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) lnd.assertMainErrContains( t, "not mapped to any account", @@ -483,8 +480,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) lnd.assertMainErrContains(t, testErr.Error())