Skip to content

Commit

Permalink
Merge pull request #934 from ellemouton/sql1Accounts1
Browse files Browse the repository at this point in the history
[sql-1]accounts: preparatory commits for  SQL-izing accounts
  • Loading branch information
ellemouton authored Jan 16, 2025
2 parents 54cf58b + 6f131a1 commit 211865a
Show file tree
Hide file tree
Showing 12 changed files with 301 additions and 188 deletions.
20 changes: 11 additions & 9 deletions accounts/checkers.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func NewAccountChecker(service Service,
}

return nil, service.AssociateInvoice(
acct.ID, hash,
ctx, acct.ID, hash,
)
}, mid.PassThroughErrorHandler,
),
Expand Down Expand Up @@ -615,12 +615,12 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params,
fee := lnrpc.CalculateFeeLimit(limit, sendAmt)
sendAmt += fee

err = service.CheckBalance(acct.ID, sendAmt)
err = service.CheckBalance(ctx, acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %w", err)
}

err = service.AssociatePayment(acct.ID, pHash, sendAmt)
err = service.AssociatePayment(ctx, acct.ID, pHash, sendAmt)
if err != nil {
return fmt.Errorf("error associating payment: %w", err)
}
Expand Down Expand Up @@ -661,11 +661,13 @@ func checkSendResponse(ctx context.Context, service Service,
if status == lnrpc.Payment_FAILED {
service.DeleteValues(reqID)

return nil, service.RemovePayment(hash)
return nil, service.RemovePayment(ctx, hash)
}

// If there is no immediate failure, make sure we track the payment.
err = service.TrackPayment(acct.ID, hash, lnwire.MilliSatoshi(fullAmt))
err = service.TrackPayment(
ctx, acct.ID, hash, lnwire.MilliSatoshi(fullAmt),
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -713,12 +715,12 @@ func checkSendToRoute(ctx context.Context, service Service, paymentHash []byte,
}
sendAmt += fee

err = service.CheckBalance(acct.ID, sendAmt)
err = service.CheckBalance(ctx, acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %w", err)
}

err = service.AssociatePayment(acct.ID, hash, sendAmt)
err = service.AssociatePayment(ctx, acct.ID, hash, sendAmt)
if err != nil {
return fmt.Errorf("error associating payment with hash %s: %w",
hash, err)
Expand Down Expand Up @@ -749,7 +751,7 @@ func erroredPaymentHandler(service Service) mid.ErrorHandler {
"hash: %s and amount: %d", reqVals.PaymentHash,
reqVals.PaymentAmount)

err = service.PaymentErrored(acct.ID, reqVals.PaymentHash)
err = service.PaymentErrored(ctx, acct.ID, reqVals.PaymentHash)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -812,7 +814,7 @@ func sendToRouteHTLCResponseHandler(service Service) func(ctx context.Context,
}

err = service.TrackPayment(
acct.ID, reqValues.PaymentHash,
ctx, acct.ID, reqValues.PaymentHash,
lnwire.MilliSatoshi(totalAmount),
)
if err != nil {
Expand Down
43 changes: 26 additions & 17 deletions accounts/checkers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func newMockService() *mockService {
}
}

func (m *mockService) CheckBalance(_ AccountID,
func (m *mockService) CheckBalance(_ context.Context, _ AccountID,
wantBalance lnwire.MilliSatoshi) error {

if wantBalance > m.acctBalanceMsat {
Expand All @@ -81,24 +81,28 @@ func (m *mockService) CheckBalance(_ AccountID,
return nil
}

func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error {
func (m *mockService) AssociateInvoice(_ context.Context, id AccountID,
hash lntypes.Hash) error {

m.trackedInvoices[hash] = id

return nil
}

func (m *mockService) AssociatePayment(id AccountID, paymentHash lntypes.Hash,
amt lnwire.MilliSatoshi) error {
func (m *mockService) AssociatePayment(_ context.Context, id AccountID,
paymentHash lntypes.Hash, amt lnwire.MilliSatoshi) error {

return nil
}

func (m *mockService) PaymentErrored(id AccountID, hash lntypes.Hash) error {
func (m *mockService) PaymentErrored(_ context.Context, id AccountID,
hash lntypes.Hash) error {

return nil
}

func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash,
amt lnwire.MilliSatoshi) error {
func (m *mockService) TrackPayment(_ context.Context, _ AccountID,
hash lntypes.Hash, amt lnwire.MilliSatoshi) error {

m.trackedPayments[hash] = &PaymentEntry{
Status: lnrpc.Payment_UNKNOWN,
Expand All @@ -108,7 +112,9 @@ func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash,
return nil
}

func (m *mockService) RemovePayment(hash lntypes.Hash) error {
func (m *mockService) RemovePayment(_ context.Context,
hash lntypes.Hash) error {

delete(m.trackedPayments, hash)

return nil
Expand Down Expand Up @@ -517,14 +523,15 @@ func testSendPayment(t *testing.T, uri string) {
errFunc := func(err error) {
lndMock.mainErrChan <- err
}
service, err := NewService(t.TempDir(), errFunc)
store := NewTestDB(t)
service, err := NewService(store, errFunc)
require.NoError(t, err)

err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
acct, err := service.Account(id)
acct, err := service.Account(ctx, id)
require.NoError(t, err)

require.Equal(t, expectedBalance,
Expand All @@ -539,7 +546,7 @@ func testSendPayment(t *testing.T, uri string) {

// Create an account and add it to the context.
acct, err := service.NewAccount(
5000, time.Now().Add(time.Hour), "test",
ctx, 5000, time.Now().Add(time.Hour), "test",
)
require.NoError(t, err)

Expand Down Expand Up @@ -713,14 +720,15 @@ func TestSendPaymentV2(t *testing.T) {
errFunc := func(err error) {
lndMock.mainErrChan <- err
}
service, err := NewService(t.TempDir(), errFunc)
store := NewTestDB(t)
service, err := NewService(store, errFunc)
require.NoError(t, err)

err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
acct, err := service.Account(id)
acct, err := service.Account(ctx, id)
require.NoError(t, err)

require.Equal(t, expectedBalance,
Expand All @@ -735,7 +743,7 @@ func TestSendPaymentV2(t *testing.T) {

// Create an account and add it to the context.
acct, err := service.NewAccount(
5000, time.Now().Add(time.Hour), "test",
ctx, 5000, time.Now().Add(time.Hour), "test",
)
require.NoError(t, err)

Expand Down Expand Up @@ -900,14 +908,15 @@ func TestSendToRouteV2(t *testing.T) {
errFunc := func(err error) {
lndMock.mainErrChan <- err
}
service, err := NewService(t.TempDir(), errFunc)
store := NewTestDB(t)
service, err := NewService(store, errFunc)
require.NoError(t, err)

err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
acct, err := service.Account(id)
acct, err := service.Account(ctx, id)
require.NoError(t, err)

require.Equal(t, expectedBalance,
Expand All @@ -922,7 +931,7 @@ func TestSendToRouteV2(t *testing.T) {

// Create an account and add it to the context.
acct, err := service.NewAccount(
5000, time.Now().Add(time.Hour), "test",
ctx, 5000, time.Now().Add(time.Hour), "test",
)
require.NoError(t, err)

Expand Down
11 changes: 11 additions & 0 deletions accounts/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package accounts

import "errors"

var (
// ErrLabelAlreadyExists is returned by the CreateAccount method if the
// account label is already used by an existing account.
ErrLabelAlreadyExists = errors.New(
"account label uniqueness constraint violation",
)
)
2 changes: 1 addition & 1 deletion accounts/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (s *InterceptorService) Intercept(ctx context.Context,
"macaroon caveat")
}

acct, err := s.Account(*acctID)
acct, err := s.Account(ctx, *acctID)
if err != nil {
return mid.RPCErrString(
req, "error getting account %x: %v", acctID[:], err,
Expand Down
38 changes: 23 additions & 15 deletions accounts/interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package accounts

import (
"context"
"encoding/hex"
"errors"
"fmt"
Expand Down Expand Up @@ -201,30 +202,34 @@ var (
type Store interface {
// NewAccount creates a new OffChainBalanceAccount with the given
// balance and a randomly chosen ID.
NewAccount(balance lnwire.MilliSatoshi, expirationDate time.Time,
label string) (*OffChainBalanceAccount, error)
NewAccount(ctx context.Context, balance lnwire.MilliSatoshi,
expirationDate time.Time, label string) (
*OffChainBalanceAccount, error)

// UpdateAccount writes an account to the database, overwriting the
// existing one if it exists.
UpdateAccount(account *OffChainBalanceAccount) error
UpdateAccount(ctx context.Context,
account *OffChainBalanceAccount) error

// Account retrieves an account from the Store and un-marshals it. If
// the account cannot be found, then ErrAccNotFound is returned.
Account(id AccountID) (*OffChainBalanceAccount, error)
Account(ctx context.Context, id AccountID) (*OffChainBalanceAccount,
error)

// Accounts retrieves all accounts from the store and un-marshals them.
Accounts() ([]*OffChainBalanceAccount, error)
Accounts(ctx context.Context) ([]*OffChainBalanceAccount, error)

// RemoveAccount finds an account by its ID and removes it from the¨
// store.
RemoveAccount(id AccountID) error
RemoveAccount(ctx context.Context, id AccountID) error

// LastIndexes returns the last invoice add and settle index or
// ErrNoInvoiceIndexKnown if no indexes are known yet.
LastIndexes() (uint64, uint64, error)
LastIndexes(ctx context.Context) (uint64, uint64, error)

// StoreLastIndexes stores the last invoice add and settle index.
StoreLastIndexes(addIndex, settleIndex uint64) error
StoreLastIndexes(ctx context.Context, addIndex,
settleIndex uint64) error

// Close closes the underlying store.
Close() error
Expand All @@ -234,34 +239,37 @@ type Store interface {
type Service interface {
// CheckBalance ensures an account is valid and has a balance equal to
// or larger than the amount that is required.
CheckBalance(id AccountID, requiredBalance lnwire.MilliSatoshi) error
CheckBalance(ctx context.Context, id AccountID,
requiredBalance lnwire.MilliSatoshi) error

// AssociateInvoice associates a generated invoice with the given
// account, making it possible for the account to be credited in case
// the invoice is paid.
AssociateInvoice(id AccountID, hash lntypes.Hash) error
AssociateInvoice(ctx context.Context, id AccountID,
hash lntypes.Hash) error

// TrackPayment adds a new payment to be tracked to the service. If the
// payment is eventually settled, its amount needs to be debited from
// the given account.
TrackPayment(id AccountID, hash lntypes.Hash,
TrackPayment(ctx context.Context, id AccountID, hash lntypes.Hash,
fullAmt lnwire.MilliSatoshi) error

// RemovePayment removes a failed payment from the service because it no
// longer needs to be tracked. The payment is certain to never succeed,
// so we never need to debit the amount from the account.
RemovePayment(hash lntypes.Hash) error
RemovePayment(ctx context.Context, hash lntypes.Hash) error

// AssociatePayment associates a payment (hash) with the given account,
// ensuring that the payment will be tracked for a user when LiT is
// restarted.
AssociatePayment(id AccountID, paymentHash lntypes.Hash,
fullAmt lnwire.MilliSatoshi) error
AssociatePayment(ctx context.Context, id AccountID,
paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error

// PaymentErrored removes a pending payment from the accounts
// registered payment list. This should only ever be called if we are
// sure that the payment request errored out.
PaymentErrored(id AccountID, hash lntypes.Hash) error
PaymentErrored(ctx context.Context, id AccountID,
hash lntypes.Hash) error

RequestValuesStore
}
Expand Down
Loading

0 comments on commit 211865a

Please sign in to comment.