diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aed7799a..c36cf0ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,18 +10,24 @@ permissions: contents: read jobs: - ci: - name: ci + lint: + name: lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - - uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6.0.0 + - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0 with: go-version: 1.25 - name: lint - uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 + uses: golangci/golangci-lint-action@e7fa5ac41e1cf5b7d48e45e42232ce7ada589601 # v9.1.0 + + test: + name: test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: test run: make docker && make docker-test diff --git a/auth/auth.go b/auth/auth.go index 467b65e2..06d52402 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -4,10 +4,12 @@ import ( "crypto/ed25519" "encoding/base64" "encoding/hex" + "errors" "fmt" "net/http" "os" "regexp" + "slices" "strconv" "strings" "time" @@ -46,7 +48,7 @@ func authenticate(tkn string) (string, error) { re := regexp.MustCompile(tokenRE) m := re.FindStringSubmatch(string(base64DecodedBytes)) if m == nil { - return "", fmt.Errorf("invalid token format") + return "", errors.New("invalid token format") } token := Token{TimestampHex: m[1], SignedTimestampHex: m[2], PublicKeyHex: m[3]} @@ -64,7 +66,7 @@ func authenticate(tkn string) (string, error) { return "", fmt.Errorf("error decoding hex string: %w", err) } if !ed25519.Verify(publicKey, timestampBytes, signedTimestamp) { - return "", fmt.Errorf("signature verification failed") + return "", errors.New("signature verification failed") } var timestamp int64 @@ -75,14 +77,12 @@ func authenticate(tkn string) (string, error) { // Verify that this token is within +/- 1 day. if abs(time.Now().UnixMilli()-timestamp) > TokenMaxDuration { - return "", fmt.Errorf("token is expired") + return "", errors.New("token is expired") } blockedIDs := strings.Split(os.Getenv("BLOCKED_CLIENT_IDS"), ",") - for _, id := range blockedIDs { - if token.PublicKeyHex == id { - return "", fmt.Errorf("This client ID is blocked") - } + if slices.Contains(blockedIDs, token.PublicKeyHex) { + return "", errors.New("this client ID is blocked") } return token.PublicKeyHex, nil @@ -98,12 +98,12 @@ func Authorize(r *http.Request) (string, error) { if ok && len(tokens) >= 1 { token = tokens[0] if !strings.HasPrefix(token, bearerPrefix) { - return "", fmt.Errorf("Not a valid token") + return "", errors.New("not a valid token") } token = strings.TrimPrefix(token, bearerPrefix) } if token == "" { - return "", fmt.Errorf("Not a valid token") + return "", errors.New("not a valid token") } // Verify token diff --git a/auth/auth_test.go b/auth/auth_test.go index c753e2ef..0a570df0 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -2,14 +2,16 @@ package auth_test import ( "encoding/base64" + "errors" "fmt" "net/http" "testing" "time" + "github.com/stretchr/testify/suite" + "github.com/brave/go-sync/auth" "github.com/brave/go-sync/auth/authtest" - "github.com/stretchr/testify/suite" ) type AuthTestSuite struct { @@ -20,14 +22,14 @@ func (suite *AuthTestSuite) TestAuthenticate() { // invalid token format id, err := auth.Authenticate(base64.URLEncoding.EncodeToString([]byte("||"))) suite.Require().Error(err, "invalid token format should fail") - suite.Require().Equal("", id, "empty clientID should be returned") + suite.Require().Empty(id, "empty clientID should be returned") // invalid signature _, tokenHex, _, err := authtest.GenerateToken(time.Now().UnixMilli()) suite.Require().NoError(err, "generate token should succeed") id, err = auth.Authenticate(base64.URLEncoding.EncodeToString([]byte("12" + tokenHex))) suite.Require().Error(err, "invalid signature should fail") - suite.Require().Equal("", id) + suite.Require().Empty(id) // valid token tkn, _, expectedID, err := authtest.GenerateToken(time.Now().UnixMilli()) @@ -41,13 +43,13 @@ func (suite *AuthTestSuite) TestAuthenticate() { suite.Require().NoError(err, "generate token should succeed") id, err = auth.Authenticate(tkn) suite.Require().Error(err, "outdated token should failed") - suite.Require().Equal("", id) + suite.Require().Empty(id) tkn, _, _, err = authtest.GenerateToken(time.Now().UnixMilli() + auth.TokenMaxDuration + 100) suite.Require().NoError(err, "generate token should succeed") id, err = auth.Authenticate(tkn) suite.Require().Error(err, "outdated token should failed") - suite.Require().Equal("", id) + suite.Require().Empty(id) } func (suite *AuthTestSuite) TestAuthorize() { @@ -59,8 +61,8 @@ func (suite *AuthTestSuite) TestAuthorize() { outdatedToken, _, _, err := authtest.GenerateToken(time.Now().UnixMilli() - auth.TokenMaxDuration - 1) suite.Require().NoError(err, "generate token should succeed") - invalidTokenErr := fmt.Errorf("Not a valid token") - outdatedErr := fmt.Errorf("error authorizing: %w", fmt.Errorf("token is expired")) + invalidTokenErr := errors.New("not a valid token") + outdatedErr := fmt.Errorf("error authorizing: %w", errors.New("token is expired")) tests := map[string]struct { token string clientID string diff --git a/cache/cache.go b/cache/cache.go index fbe376ff..daeebd4b 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -33,7 +33,7 @@ func GetInterimCountKey(clientID string, countType string) string { return clientID + "#interim_" + countType } -// GetAndClearInterimCount returns the amount of entities inserted in +// GetInterimCount returns the amount of entities inserted in // the DB that were not yet added to the item count func (c *Cache) GetInterimCount(ctx context.Context, clientID string, countType string, clearCache bool) (int, error) { countStr, err := c.Get(ctx, GetInterimCountKey(clientID, countType), clearCache) diff --git a/cache/cache_test.go b/cache/cache_test.go index 3ca4deb5..8d4a70f9 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -4,12 +4,14 @@ import ( "context" "testing" - "github.com/brave/go-sync/cache" "github.com/stretchr/testify/suite" + + "github.com/brave/go-sync/cache" ) type CacheTestSuite struct { suite.Suite + cache *cache.Cache } @@ -21,7 +23,7 @@ func (suite *CacheTestSuite) TestSetTypeMtime() { suite.cache.SetTypeMtime(context.Background(), "id", 123, 12345678) val, err := suite.cache.Get(context.Background(), "id#123", false) suite.Require().NoError(err) - suite.Require().Equal(val, "12345678") + suite.Require().Equal("12345678", val) } func (suite *CacheTestSuite) TestIsTypeMtimeUpdated() { diff --git a/command/command.go b/command/command.go index 85ff541d..8b8b541b 100644 --- a/command/command.go +++ b/command/command.go @@ -3,13 +3,15 @@ package command import ( "context" "encoding/binary" + "errors" "fmt" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/rs/zerolog/log" + "github.com/brave/go-sync/cache" "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/schema/protobuf/sync_pb" - "github.com/rs/zerolog/log" ) var ( @@ -25,7 +27,6 @@ const ( setSyncPollInterval int32 = 30 nigoriTypeID int32 = 47745 deviceInfoTypeID int = 154522 - maxActiveDevices int = 50 historyCountTypeStr string = "history" normalCountTypeStr string = "normal" ) @@ -54,10 +55,10 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag activeDevices++ } - // Error out when exceeds the limit. - if activeDevices >= maxActiveDevices { + // Error out when device limit has been reached. + if hasReachedDeviceLimit(activeDevices, clientID) { errCode = sync_pb.SyncEnums_THROTTLED - return &errCode, fmt.Errorf("exceed limit of active devices in a chain") + return &errCode, errors.New("exceed limit of active devices in a chain") } } @@ -125,7 +126,7 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag token, n := binary.Varint(guRsp.NewProgressMarker[i].Token) if n <= 0 { - return nil, fmt.Errorf("Failed at decoding token value %v", token) + return nil, fmt.Errorf("failed at decoding token value %v", token) } // Check cache to short circuit with 0 updates for polling requests. @@ -150,7 +151,7 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag if isNewClient && *fromProgressMarker.DataTypeId == nigoriTypeID && token == 0 && len(entities) == 0 { errCode = sync_pb.SyncEnums_TRANSIENT_ERROR - return &errCode, fmt.Errorf("nigori root folder entity is not ready yet") + return &errCode, errors.New("nigori root folder entity is not ready yet") } if hasChangesRemaining { @@ -224,7 +225,7 @@ func getInterimItemCounts(cache *cache.Cache, clientID string, clearCache bool) // - existed sync entity will be updated if version is greater than 0. func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, commitRsp *sync_pb.CommitResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { if commitMsg == nil { - return nil, fmt.Errorf("nil commitMsg is received") + return nil, errors.New("nil commitMsg is received") } errCode := sync_pb.SyncEnums_SUCCESS // default value, might be changed later @@ -258,6 +259,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // Map to save commit data type ID & mtime typeMtimeMap := make(map[int]int64) for i, v := range commitMsg.Entries { + var conflict, deleted bool entryRsp := &sync_pb.CommitResponse_EntryResponse{} commitRsp.Entryresponse[i] = entryRsp @@ -305,7 +307,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // Insert all non-history items. For history items, ignore any items above history quoto // and lie to the client about the objects being synced instead of returning OVER_QUOTA // so the client can continue to sync other entities. - conflict, err := db.InsertSyncEntity(entityToCommit) + conflict, err = db.InsertSyncEntity(entityToCommit) if err != nil { log.Error().Err(err).Msg("Insert sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -330,7 +332,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c } } } else { // Update - conflict, deleted, err := db.UpdateSyncEntity(entityToCommit, oldVersion) + conflict, deleted, err = db.UpdateSyncEntity(entityToCommit, oldVersion) if err != nil { log.Error().Err(err).Msg("Update sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -354,7 +356,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c if err != nil { log.Error().Err(err).Msg("Interim count update failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR - return &errCode, fmt.Errorf("Interim count update failed: %w", err) + return &errCode, fmt.Errorf("interim count update failed: %w", err) } typeMtimeMap[*entityToCommit.DataType] = *entityToCommit.Mtime @@ -443,7 +445,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM var err error if pb.MessageContents == nil { - return fmt.Errorf("nil pb.MessageContents received") + return errors.New("nil pb.MessageContents received") } else if *pb.MessageContents == sync_pb.ClientToServerMessage_GET_UPDATES { guRsp := &sync_pb.GetUpdatesResponse{} pbRsp.GetUpdates = guRsp @@ -487,7 +489,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM return fmt.Errorf("error handling ClearServerData request: %w", err) } } else { - return fmt.Errorf("unsupported message type of ClientToServerMessage") + return errors.New("unsupported message type of ClientToServerMessage") } return nil diff --git a/command/command_test.go b/command/command_test.go index 4c9066e4..e646d5c5 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -4,18 +4,20 @@ import ( "context" "encoding/binary" "encoding/json" + "fmt" "sort" "strconv" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/suite" + "github.com/brave/go-sync/cache" "github.com/brave/go-sync/command" "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/datastore/datastoretest" "github.com/brave/go-sync/schema/protobuf/sync_pb" - "github.com/stretchr/testify/suite" ) const ( @@ -27,6 +29,7 @@ const ( type CommandTestSuite struct { suite.Suite + dynamo *datastore.Dynamo cache *cache.Cache } @@ -96,6 +99,15 @@ func getBookmarkSpecifics() *sync_pb.EntitySpecifics { } } +func getDeviceInfoSpecifics() *sync_pb.EntitySpecifics { + deviceInfoEntitySpecifics := &sync_pb.EntitySpecifics_DeviceInfo{ + DeviceInfo: &sync_pb.DeviceInfoSpecifics{}, + } + return &sync_pb.EntitySpecifics{ + SpecificsVariant: deviceInfoEntitySpecifics, + } +} + func getCommitEntity(id string, version int64, deleted bool, specifics *sync_pb.EntitySpecifics) *sync_pb.SyncEntity { return &sync_pb.SyncEntity{ IdString: aws.String(id), @@ -128,7 +140,7 @@ func getClientToServerCommitMsg(entries []*sync_pb.SyncEntity) *sync_pb.ClientTo func getMarker(suite *CommandTestSuite, tokens []int64) []*sync_pb.DataTypeProgressMarker { types := []int32{nigoriType, bookmarkType} // hard-coded types used in tests. - suite.Assert().Equal(len(types), len(tokens)) + suite.Len(tokens, len(types)) marker := []*sync_pb.DataTypeProgressMarker{} for i, token := range tokens { tokenBytes := make([]byte, binary.MaxVarintLen64) @@ -156,26 +168,26 @@ func getClientToServerGUMsg(marker []*sync_pb.DataTypeProgressMarker, func getTokensFromNewMarker(suite *CommandTestSuite, newMarker []*sync_pb.DataTypeProgressMarker) (int64, int64) { nigoriToken, n := binary.Varint(newMarker[0].Token) - suite.Assert().Greater(n, 0) + suite.Positive(n) bookmarkToken, n := binary.Varint(newMarker[1].Token) - suite.Assert().Greater(n, 0) + suite.Positive(n) return nigoriToken, bookmarkToken } func assertCommonResponse(suite *CommandTestSuite, rsp *sync_pb.ClientToServerResponse, isCommit bool) { - suite.Assert().Equal(sync_pb.SyncEnums_SUCCESS, *rsp.ErrorCode, "errorCode should match") - suite.Assert().Equal(getClientCommand(), rsp.ClientCommand, "ClientCommand should match") - suite.Assert().Equal(command.StoreBirthday, *rsp.StoreBirthday, "Birthday should match") + suite.Equal(sync_pb.SyncEnums_SUCCESS, *rsp.ErrorCode, "errorCode should match") + suite.Equal(getClientCommand(), rsp.ClientCommand, "ClientCommand should match") + suite.Equal(command.StoreBirthday, *rsp.StoreBirthday, "Birthday should match") if isCommit { - suite.Assert().NotNil(rsp.Commit) + suite.NotNil(rsp.Commit) } else { - suite.Assert().NotNil(rsp.GetUpdates) + suite.NotNil(rsp.GetUpdates) } } func assertGetUpdatesResponse(suite *CommandTestSuite, rsp *sync_pb.GetUpdatesResponse, newMarker *[]*sync_pb.DataTypeProgressMarker, expectedPBSyncAttrs []*PBSyncAttrs, - expectedChangesRemaining int64) { + expectedChangesRemaining int64) { //nolint:unparam PBSyncAttrs := []*PBSyncAttrs{} for _, entity := range rsp.Entries { // Update tokens in the expected NewProgressMarker @@ -186,7 +198,7 @@ func assertGetUpdatesResponse(suite *CommandTestSuite, rsp *sync_pb.GetUpdatesRe tokenPtr = &(*newMarker)[1].Token } token, n := binary.Varint(*tokenPtr) - suite.Assert().Greater(n, 0) + suite.Positive(n) if token < *entity.Mtime { binary.PutVarint(*tokenPtr, *entity.Mtime) } @@ -204,10 +216,10 @@ func assertGetUpdatesResponse(suite *CommandTestSuite, rsp *sync_pb.GetUpdatesRe suite.Require().NoError(err, "json.Marshal should succeed") s2, err := json.Marshal(PBSyncAttrs) suite.Require().NoError(err, "json.Marshal should succeed") - suite.Assert().Equal(s1, s2) + suite.Equal(s1, s2) - suite.Assert().Equal(*newMarker, rsp.NewProgressMarker) - suite.Assert().Equal(expectedChangesRemaining, *rsp.ChangesRemaining) + suite.Equal(*newMarker, rsp.NewProgressMarker) + suite.Equal(expectedChangesRemaining, *rsp.ChangesRemaining) } func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { @@ -224,13 +236,13 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 2) commitSuccess := sync_pb.CommitResponse_SUCCESS serverIDs := []string{} commitVersions := []int64{} for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) - suite.Assert().Equal(*entryRsp.Mtime, *entryRsp.Version) + suite.Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(*entryRsp.Mtime, *entryRsp.Version) serverIDs = append(serverIDs, *entryRsp.IdString) commitVersions = append(commitVersions, *entryRsp.Version) } @@ -269,12 +281,12 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(4, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 4) serverIDs = []string{} commitVersions = []int64{} for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) - suite.Assert().Equal(*entryRsp.Mtime, *entryRsp.Version) + suite.Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(*entryRsp.Mtime, *entryRsp.Version) serverIDs = append(serverIDs, *entryRsp.IdString) commitVersions = append(commitVersions, *entryRsp.Version) } @@ -315,10 +327,10 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 2) commitConflict := sync_pb.CommitResponse_CONFLICT for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitConflict, *entryRsp.ResponseType) + suite.Equal(commitConflict, *entryRsp.ResponseType) } // GetUpdates again with previous returned tokens should return 0 updates. @@ -373,7 +385,54 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_NewClient() { // Check dummy encryption keys only for NEW_CLIENT case. expectedEncryptionKeys := make([][]byte, 1) expectedEncryptionKeys[0] = []byte("1234") - suite.Assert().Equal(expectedEncryptionKeys, rsp.GetUpdates.EncryptionKeys) + suite.Equal(expectedEncryptionKeys, rsp.GetUpdates.EncryptionKeys) +} + +func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceeded() { + highDeviceLimitClientID := "high_device_limit_client_id" + command.LoadHighDeviceLimitClientIDs(fmt.Sprintf("randomid,%s,anotherrandomid", highDeviceLimitClientID)) + + testCases := []struct { + clientID string + expectedDeviceLimit int + }{ + {clientID: "client_id_1", expectedDeviceLimit: 50}, + {clientID: highDeviceLimitClientID, expectedDeviceLimit: 100}, + } + + for _, testCase := range testCases { + // Simulate devices calling GetUpdates with NEW_CLIENT origin up to the expected device limit. + marker := getMarker(suite, []int64{0, 0}) + msg := getClientToServerGUMsg( + marker, sync_pb.SyncEnums_NEW_CLIENT, true, nil) + for i := 1; i <= testCase.expectedDeviceLimit; i++ { + rsp := &sync_pb.ClientToServerResponse{} + + suite.Require().NoError( + command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + "HandleClientToServerMessage should succeed for device %d", i) + suite.Equal(sync_pb.SyncEnums_SUCCESS, *rsp.ErrorCode, "device %d should succeed", i) + suite.NotNil(rsp.GetUpdates, "device %d should have GetUpdates response", i) + + // Commit a device info entity after GetUpdates + deviceEntry := getCommitEntity(fmt.Sprintf("device_%d", i), 0, false, getDeviceInfoSpecifics()) + commitMsg := getClientToServerCommitMsg([]*sync_pb.SyncEntity{deviceEntry}) + commitRsp := &sync_pb.ClientToServerResponse{} + suite.Require().NoError( + command.HandleClientToServerMessage(suite.cache, commitMsg, commitRsp, suite.dynamo, testCase.clientID), + "Commit device info should succeed for device %d", i) + suite.Equal(sync_pb.SyncEnums_SUCCESS, *commitRsp.ErrorCode, "Commit device info should succeed for device %d", i) + } + + // should get THROTTLED error when device limit is exceeded + rsp := &sync_pb.ClientToServerResponse{} + suite.Require().NoError( + command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + "HandleClientToServerMessage should succeed") + suite.Equal(sync_pb.SyncEnums_THROTTLED, *rsp.ErrorCode, "errorCode should be THROTTLED") + suite.Require().NotNil(rsp.ErrorMessage, "error message should be present") + suite.Contains(*rsp.ErrorMessage, "exceed limit of active devices") + } } func (suite *CommandTestSuite) TestHandleClientToServerMessage_GUBatchSize() { @@ -392,11 +451,11 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_GUBatchSize() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(4, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 4) commitSuccess := sync_pb.CommitResponse_SUCCESS for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) - suite.Assert().Equal(*entryRsp.Mtime, *entryRsp.Version) + suite.Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(*entryRsp.Mtime, *entryRsp.Version) } } @@ -416,13 +475,13 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 2) commitSuccess := sync_pb.CommitResponse_SUCCESS serverIDs := []string{} commitVersions := []int64{} for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) - suite.Assert().Equal(*entryRsp.Mtime, *entryRsp.Version) + suite.Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(*entryRsp.Mtime, *entryRsp.Version) serverIDs = append(serverIDs, *entryRsp.IdString) commitVersions = append(commitVersions, *entryRsp.Version) } @@ -441,13 +500,13 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(4, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 4) overQuota := sync_pb.CommitResponse_OVER_QUOTA expectedEntryRsp := []sync_pb.CommitResponse_ResponseType{commitSuccess, commitSuccess, overQuota, overQuota} expectedVersion := []*int64{rsp.Commit.Entryresponse[0].Mtime, rsp.Commit.Entryresponse[1].Mtime, nil, nil} for i, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(expectedEntryRsp[i], *entryRsp.ResponseType) - suite.Assert().Equal(expectedVersion[i], entryRsp.Version) + suite.Equal(expectedEntryRsp[i], *entryRsp.ResponseType) + suite.Equal(expectedVersion[i], entryRsp.Version) } // Commit 2 items again when quota is already exceed should get two OVER_QUOTA @@ -462,9 +521,9 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 2) for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(overQuota, *entryRsp.ResponseType) + suite.Equal(overQuota, *entryRsp.ResponseType) } // Commit updates to delete two previous inserted items. @@ -479,10 +538,10 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 2) for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) - suite.Assert().Equal(*entryRsp.Mtime, *entryRsp.Version) + suite.Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(*entryRsp.Mtime, *entryRsp.Version) } // Commit 4 items should have two success and two OVER_QUOTA. @@ -499,14 +558,14 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(4, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 4) expectedVersion = []*int64{rsp.Commit.Entryresponse[0].Mtime, rsp.Commit.Entryresponse[1].Mtime, nil, nil} for i, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(expectedEntryRsp[i], *entryRsp.ResponseType) + suite.Equal(expectedEntryRsp[i], *entryRsp.ResponseType) if *entryRsp.ResponseType == commitSuccess { - suite.Assert().Equal(*expectedVersion[i], *entryRsp.Version) + suite.Equal(*expectedVersion[i], *entryRsp.Version) } else { - suite.Assert().Equal(expectedVersion[i], entryRsp.Version) + suite.Equal(expectedVersion[i], entryRsp.Version) } } @@ -521,10 +580,10 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(1, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 1) commitSuccess := sync_pb.CommitResponse_SUCCESS for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(commitSuccess, *entryRsp.ResponseType) } // Commit parents with its child bookmarks in one commit request. @@ -550,9 +609,9 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(6, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 6) for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(commitSuccess, *entryRsp.ResponseType) } // Get updates to check if child's parent ID is replaced with the server @@ -565,26 +624,26 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) - suite.Require().Equal(6, len(rsp.GetUpdates.Entries)) - for i := 0; i < len(rsp.GetUpdates.Entries); i++ { + suite.Require().Len(rsp.GetUpdates.Entries, 6) + for i := range rsp.GetUpdates.Entries { if i != len(rsp.GetUpdates.Entries)-1 { - suite.Assert().Equal(rsp.GetUpdates.Entries[i].OriginatorClientItemId, entries[i].IdString) + suite.Equal(rsp.GetUpdates.Entries[i].OriginatorClientItemId, entries[i].IdString) } else { - suite.Assert().Equal(rsp.GetUpdates.Entries[i].OriginatorClientItemId, child0.IdString) + suite.Equal(rsp.GetUpdates.Entries[i].OriginatorClientItemId, child0.IdString) } - suite.Assert().NotNil(rsp.GetUpdates.Entries[i].IdString) + suite.NotNil(rsp.GetUpdates.Entries[i].IdString) } - suite.Assert().Equal(rsp.GetUpdates.Entries[1].ParentIdString, rsp.GetUpdates.Entries[0].IdString) - suite.Assert().Equal(rsp.GetUpdates.Entries[3].ParentIdString, rsp.GetUpdates.Entries[0].IdString) - suite.Assert().Equal(rsp.GetUpdates.Entries[4].ParentIdString, rsp.GetUpdates.Entries[2].IdString) - suite.Assert().Equal(rsp.GetUpdates.Entries[5].ParentIdString, rsp.GetUpdates.Entries[0].IdString) + suite.Equal(rsp.GetUpdates.Entries[1].ParentIdString, rsp.GetUpdates.Entries[0].IdString) + suite.Equal(rsp.GetUpdates.Entries[3].ParentIdString, rsp.GetUpdates.Entries[0].IdString) + suite.Equal(rsp.GetUpdates.Entries[4].ParentIdString, rsp.GetUpdates.Entries[2].IdString) + suite.Equal(rsp.GetUpdates.Entries[5].ParentIdString, rsp.GetUpdates.Entries[0].IdString) } func assertTypeMtimeCacheValue(suite *CommandTestSuite, key string, mtime int64, errMsg string) { val, err := suite.cache.Get(context.Background(), key, false) suite.Require().NoError(err, "cache.Get should succeed") - suite.Assert().Equal(val, strconv.FormatInt(mtime, 10), errMsg) + suite.Equal(val, strconv.FormatInt(mtime, 10), errMsg) } func insertSyncEntitiesWithoutUpdateCache( @@ -619,12 +678,12 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(3, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 3) commitSuccess := sync_pb.CommitResponse_SUCCESS var latestBookmarkMtime int64 var latestNigoriMtime int64 for i, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(commitSuccess, *entryRsp.ResponseType) if i < 2 { latestBookmarkMtime = *entryRsp.Mtime } @@ -660,7 +719,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) - suite.Assert().Equal(0, len(rsp.GetUpdates.Entries)) + suite.Empty(rsp.GetUpdates.Entries) assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), latestBookmarkMtime, "cache is not updated when short circuited") assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(nigoriType)), @@ -683,9 +742,9 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(1, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 1) entryRsp := rsp.Commit.Entryresponse[0] - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) + suite.Equal(commitSuccess, *entryRsp.ResponseType) latestBookmarkMtime = *entryRsp.Mtime assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), @@ -701,9 +760,9 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) - suite.Assert().Equal(2, len(rsp.GetUpdates.Entries)) - suite.Assert().Equal(latestNigoriMtime, *rsp.GetUpdates.Entries[0].Mtime) - suite.Assert().Equal(latestBookmarkMtime, *rsp.GetUpdates.Entries[1].Mtime) + suite.Len(rsp.GetUpdates.Entries, 2) + suite.Equal(latestNigoriMtime, *rsp.GetUpdates.Entries[0].Mtime) + suite.Equal(latestBookmarkMtime, *rsp.GetUpdates.Entries[1].Mtime) assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), latestBookmarkMtime, "Cached token should be equal to latest mtime") assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(nigoriType)), @@ -722,9 +781,9 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Sk command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(1, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 1) commitSuccess := sync_pb.CommitResponse_SUCCESS - suite.Assert().Equal(commitSuccess, *rsp.Commit.Entryresponse[0].ResponseType) + suite.Equal(commitSuccess, *rsp.Commit.Entryresponse[0].ResponseType) latestBookmarkMtime := *rsp.Commit.Entryresponse[0].Mtime assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), latestBookmarkMtime, @@ -749,7 +808,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Sk command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) - suite.Require().Equal(1, len(rsp.GetUpdates.Entries)) + suite.Require().Len(rsp.GetUpdates.Entries, 1) suite.Require().Equal(dbEntries[0].Mtime, rsp.GetUpdates.Entries[0].Mtime) assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), *dbEntries[0].Mtime, "Successful commit should update the cache") @@ -768,12 +827,12 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) - suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) + suite.Len(rsp.Commit.Entryresponse, 2) commitSuccess := sync_pb.CommitResponse_SUCCESS var latestBookmarkMtime int64 for _, entryRsp := range rsp.Commit.Entryresponse { - suite.Assert().Equal(commitSuccess, *entryRsp.ResponseType) - suite.Assert().NotEqual(latestBookmarkMtime, *entryRsp.Mtime) + suite.Equal(commitSuccess, *entryRsp.ResponseType) + suite.NotEqual(latestBookmarkMtime, *entryRsp.Mtime) latestBookmarkMtime = *entryRsp.Mtime } assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), @@ -791,7 +850,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) - suite.Require().Equal(2, len(rsp.GetUpdates.Entries)) + suite.Require().Len(rsp.GetUpdates.Entries, 2) suite.Require().Equal(int64(0), *rsp.GetUpdates.ChangesRemaining) mtime := *rsp.GetUpdates.Entries[0].Mtime assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), @@ -808,7 +867,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) - suite.Require().Equal(1, len(rsp.GetUpdates.Entries)) + suite.Require().Len(rsp.GetUpdates.Entries, 1) suite.Require().Equal(int64(0), *rsp.GetUpdates.ChangesRemaining) assertTypeMtimeCacheValue(suite, clientID+"#"+strconv.Itoa(int(bookmarkType)), latestBookmarkMtime, diff --git a/command/device_limit.go b/command/device_limit.go new file mode 100644 index 00000000..15095054 --- /dev/null +++ b/command/device_limit.go @@ -0,0 +1,38 @@ +package command + +import ( + "os" + "strings" +) + +const ( + maxActiveDevices int = 50 + highMaxActiveDevices int = 100 +) + +var ( + highDeviceLimitClientIDs map[string]bool +) + +func init() { + clientIDsEnv := os.Getenv("HIGH_DEVICE_LIMIT_CLIENT_IDS") + LoadHighDeviceLimitClientIDs(clientIDsEnv) +} + +func LoadHighDeviceLimitClientIDs(clientIDList string) { + highDeviceLimitClientIDs = make(map[string]bool) + if clientIDList != "" { + ids := strings.Split(clientIDList, ",") + for _, id := range ids { + highDeviceLimitClientIDs[strings.ToLower(strings.TrimSpace(id))] = true + } + } +} + +func hasReachedDeviceLimit(activeDevices int, clientID string) bool { + limit := maxActiveDevices + if highDeviceLimitClientIDs[strings.ToLower(clientID)] { + limit = highMaxActiveDevices + } + return activeDevices >= limit +} diff --git a/command/server_defined_unique_entity.go b/command/server_defined_unique_entity.go index 31dce609..8c81f310 100644 --- a/command/server_defined_unique_entity.go +++ b/command/server_defined_unique_entity.go @@ -4,10 +4,11 @@ import ( "fmt" "time" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/google/uuid" + "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/schema/protobuf/sync_pb" - "github.com/satori/go.uuid" ) const ( @@ -28,7 +29,7 @@ func createServerDefinedUniqueEntity(name string, serverDefinedTag string, clien deleted := false folder := true version := int64(1) - idString := uuid.NewV4().String() + idString := uuid.New().String() pbEntity := &sync_pb.SyncEntity{ Ctime: &now, Mtime: &now, Deleted: &deleted, Folder: &folder, diff --git a/command/server_defined_unique_entity_test.go b/command/server_defined_unique_entity_test.go index 94832250..ea5119c3 100644 --- a/command/server_defined_unique_entity_test.go +++ b/command/server_defined_unique_entity_test.go @@ -4,15 +4,17 @@ import ( "sort" "testing" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/suite" + "github.com/brave/go-sync/command" "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/datastore/datastoretest" - "github.com/stretchr/testify/suite" ) type ServerDefinedUniqueEntityTestSuite struct { suite.Suite + dynamo *datastore.Dynamo } @@ -107,9 +109,9 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEn // Check that Ctime and Mtime have been set, reset to zero value for subsequent // tests - for i := 0; i < len(tagItems); i++ { - suite.Assert().NotNil(tagItems[i].Ctime) - suite.Assert().NotNil(tagItems[i].Mtime) + for i := range tagItems { + suite.NotNil(tagItems[i].Ctime) + suite.NotNil(tagItems[i].Mtime) tagItems[i].Ctime = nil tagItems[i].Mtime = nil @@ -117,7 +119,7 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEn sort.Sort(datastore.TagItemByClientIDID(tagItems)) sort.Sort(datastore.TagItemByClientIDID(expectedTagItems)) - suite.Assert().Equal(tagItems, expectedTagItems) + suite.Equal(expectedTagItems, tagItems) syncItems, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") @@ -133,7 +135,7 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEn break } } - suite.Assert().NotEqual(bookmarksRootID, "", "Cannot find ID of bookmarks root folder") + suite.NotEmpty(bookmarksRootID, "Cannot find ID of bookmarks root folder") // For each item returned by ScanSyncEntities, make sure it is in the map and // its value is matched, then remove it from the map. @@ -148,11 +150,11 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEn Folder: item.Folder, } - suite.Assert().NotNil(item.ServerDefinedUniqueTag) - suite.Assert().Equal(syncAttrs, *expectedSyncAttrsMap[*item.ServerDefinedUniqueTag]) + suite.NotNil(item.ServerDefinedUniqueTag) + suite.Equal(syncAttrs, *expectedSyncAttrsMap[*item.ServerDefinedUniqueTag]) delete(expectedSyncAttrsMap, *item.ServerDefinedUniqueTag) } - suite.Assert().Equal(0, len(expectedSyncAttrsMap)) + suite.Empty(expectedSyncAttrsMap) suite.Require().NoError( command.InsertServerDefinedUniqueEntities(suite.dynamo, "client2"), diff --git a/controller/controller.go b/controller/controller.go index 7566072c..382d17e7 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -3,20 +3,20 @@ package controller import ( "compress/gzip" "io" - "io/ioutil" "net/http" "github.com/brave-intl/bat-go/libs/closers" "github.com/brave-intl/bat-go/libs/middleware" + "github.com/go-chi/chi/v5" + "github.com/rs/zerolog/log" + "google.golang.org/protobuf/proto" + "github.com/brave/go-sync/cache" "github.com/brave/go-sync/command" - syncContext "github.com/brave/go-sync/context" "github.com/brave/go-sync/datastore" syncMiddleware "github.com/brave/go-sync/middleware" "github.com/brave/go-sync/schema/protobuf/sync_pb" - "github.com/go-chi/chi/v5" - "github.com/rs/zerolog/log" - "google.golang.org/protobuf/proto" + syncContext "github.com/brave/go-sync/synccontext" ) const ( @@ -55,7 +55,7 @@ func Command(cache *cache.Cache, db datastore.Datastore) http.HandlerFunc { reader = gr } - msg, err := ioutil.ReadAll(io.LimitReader(reader, payloadLimit10MB)) + msg, err := io.ReadAll(io.LimitReader(reader, payloadLimit10MB)) if err != nil { log.Error().Err(err).Msg("Read request body failed") http.Error(w, "Read request body error", http.StatusInternalServerError) diff --git a/controller/controller_test.go b/controller/controller_test.go index fe818f33..16ddd97f 100644 --- a/controller/controller_test.go +++ b/controller/controller_test.go @@ -9,20 +9,22 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" + "github.com/brave/go-sync/auth/authtest" "github.com/brave/go-sync/cache" - syncContext "github.com/brave/go-sync/context" "github.com/brave/go-sync/controller" "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/datastore/datastoretest" "github.com/brave/go-sync/schema/protobuf/sync_pb" - "github.com/stretchr/testify/suite" - "google.golang.org/protobuf/proto" + syncContext "github.com/brave/go-sync/synccontext" ) type ControllerTestSuite struct { suite.Suite + dynamo *datastore.Dynamo cache *cache.Cache } diff --git a/datastore/datastoretest/dynamo.go b/datastore/datastoretest/dynamo.go index 2c831ab0..7cbbdc70 100644 --- a/datastore/datastoretest/dynamo.go +++ b/datastore/datastoretest/dynamo.go @@ -1,44 +1,49 @@ package datastoretest import ( + "context" "encoding/json" + "errors" "fmt" - "io/ioutil" + "os" "path/filepath" "runtime" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/expression" "github.com/brave/go-sync/datastore" ) // DeleteTable deletes datastore.Table in dynamoDB. func DeleteTable(dynamo *datastore.Dynamo) error { - _, err := dynamo.DeleteTable( + _, err := dynamo.DeleteTable(context.Background(), &dynamodb.DeleteTableInput{TableName: aws.String(datastore.Table)}) if err != nil { - if aerr, ok := err.(awserr.Error); ok { + var notFoundException *types.ResourceNotFoundException + if errors.As(err, ¬FoundException) { // Return as successful if the table is not existed. - if aerr.Code() == dynamodb.ErrCodeResourceNotFoundException { - return nil - } - } else { - return fmt.Errorf("error deleting table: %w", err) + return nil } + return fmt.Errorf("error deleting table: %w", err) } - return dynamo.WaitUntilTableNotExists( - &dynamodb.DescribeTableInput{TableName: aws.String(datastore.Table)}) + // Wait for table to be deleted using waiter + waiter := dynamodb.NewTableNotExistsWaiter(dynamo) + return waiter.Wait(context.Background(), + &dynamodb.DescribeTableInput{TableName: aws.String(datastore.Table)}, + 5*time.Minute) } // CreateTable creates datastore.Table in dynamoDB. func CreateTable(dynamo *datastore.Dynamo) error { _, b, _, _ := runtime.Caller(0) root := filepath.Join(filepath.Dir(b), "../../") - raw, err := ioutil.ReadFile(filepath.Join(root, "schema/dynamodb/table.json")) + raw, err := os.ReadFile(filepath.Join(root, "schema/dynamodb/table.json")) if err != nil { return fmt.Errorf("error reading table.json: %w", err) } @@ -50,13 +55,16 @@ func CreateTable(dynamo *datastore.Dynamo) error { } input.TableName = aws.String(datastore.Table) - _, err = dynamo.CreateTable(&input) + _, err = dynamo.CreateTable(context.Background(), &input) if err != nil { return fmt.Errorf("error creating table: %w", err) } - return dynamo.WaitUntilTableExists( - &dynamodb.DescribeTableInput{TableName: aws.String(datastore.Table)}) + // Wait for table to be active using waiter + waiter := dynamodb.NewTableExistsWaiter(dynamo) + return waiter.Wait(context.Background(), + &dynamodb.DescribeTableInput{TableName: aws.String(datastore.Table)}, + 5*time.Minute) } // ResetTable deletes and creates datastore.Table in dynamoDB. @@ -81,12 +89,12 @@ func ScanSyncEntities(dynamo *datastore.Dynamo) ([]datastore.SyncEntity, error) FilterExpression: expr.Filter(), TableName: aws.String(datastore.Table), } - out, err := dynamo.Scan(input) + out, err := dynamo.Scan(context.Background(), input) if err != nil { return nil, fmt.Errorf("error doing scan for sync entities: %w", err) } syncItems := []datastore.SyncEntity{} - err = dynamodbattribute.UnmarshalListOfMaps(out.Items, &syncItems) + err = attributevalue.UnmarshalListOfMaps(out.Items, &syncItems) if err != nil { return nil, fmt.Errorf("error unmarshalling sync entitites: %w", err) } @@ -110,12 +118,12 @@ func ScanTagItems(dynamo *datastore.Dynamo) ([]datastore.ServerClientUniqueTagIt FilterExpression: expr.Filter(), TableName: aws.String(datastore.Table), } - out, err := dynamo.Scan(input) + out, err := dynamo.Scan(context.Background(), input) if err != nil { return nil, fmt.Errorf("error doing scan for tag items: %w", err) } tagItems := []datastore.ServerClientUniqueTagItem{} - err = dynamodbattribute.UnmarshalListOfMaps(out.Items, &tagItems) + err = attributevalue.UnmarshalListOfMaps(out.Items, &tagItems) if err != nil { return nil, fmt.Errorf("error unmarshalling tag items: %w", err) } @@ -138,12 +146,12 @@ func ScanClientItemCounts(dynamo *datastore.Dynamo) ([]datastore.ClientItemCount FilterExpression: expr.Filter(), TableName: aws.String(datastore.Table), } - out, err := dynamo.Scan(input) + out, err := dynamo.Scan(context.Background(), input) if err != nil { return nil, fmt.Errorf("error doing scan for item counts: %w", err) } clientItemCounts := []datastore.ClientItemCounts{} - err = dynamodbattribute.UnmarshalListOfMaps(out.Items, &clientItemCounts) + err = attributevalue.UnmarshalListOfMaps(out.Items, &clientItemCounts) if err != nil { return nil, fmt.Errorf("error unmarshalling item counts: %w", err) } diff --git a/datastore/datastoretest/mock_datastore.go b/datastore/datastoretest/mock_datastore.go index 09665d40..5cda4da5 100644 --- a/datastore/datastoretest/mock_datastore.go +++ b/datastore/datastoretest/mock_datastore.go @@ -1,8 +1,9 @@ package datastoretest import ( - "github.com/brave/go-sync/datastore" "github.com/stretchr/testify/mock" + + "github.com/brave/go-sync/datastore" ) // MockDatastore is used to mock datastorein tests diff --git a/datastore/dynamo.go b/datastore/dynamo.go index 27ad05a1..c2ab0eaf 100644 --- a/datastore/dynamo.go +++ b/datastore/dynamo.go @@ -1,14 +1,15 @@ package datastore import ( + "context" "fmt" "net/http" "os" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" ) const ( @@ -36,7 +37,7 @@ type PrimaryKey struct { // Dynamo is a Datastore wrapper around a dynamoDB. type Dynamo struct { - *dynamodb.DynamoDB + *dynamodb.Client } // NewDynamo returns a dynamoDB client to be used. @@ -49,13 +50,24 @@ func NewDynamo() (*Dynamo, error) { }, } - awsConfig := aws.NewConfig().WithRegion(os.Getenv("AWS_REGION")).WithEndpoint(os.Getenv("AWS_ENDPOINT")).WithHTTPClient(httpClient) - sess, err := session.NewSession(awsConfig) + ctx := context.Background() + // Load default AWS configuration + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(os.Getenv("AWS_REGION")), + config.WithHTTPClient(httpClient), + ) if err != nil { - return nil, fmt.Errorf("error creating new AWS session: %w", err) + return nil, fmt.Errorf("error loading AWS config: %w", err) } - db := dynamodb.New(sess) + // Create DynamoDB client with optional endpoint override + endpoint := os.Getenv("AWS_ENDPOINT") + db := dynamodb.NewFromConfig(cfg, func(o *dynamodb.Options) { + if endpoint != "" { + o.BaseEndpoint = aws.String(endpoint) + } + }) + return &Dynamo{db}, nil } diff --git a/datastore/item_count.go b/datastore/item_count.go index 1cda9903..b24d3f24 100644 --- a/datastore/item_count.go +++ b/datastore/item_count.go @@ -1,13 +1,15 @@ package datastore import ( + "context" "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) const ( @@ -65,23 +67,22 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou if err != nil { return fmt.Errorf("error building history item count query: %w", err) } - selectCount := dynamodb.SelectCount historyCountInput := &dynamodb.QueryInput{ ExpressionAttributeNames: expr.Names(), ExpressionAttributeValues: expr.Values(), KeyConditionExpression: expr.KeyCondition(), FilterExpression: expr.Filter(), TableName: aws.String(Table), - Select: &selectCount, + Select: types.SelectCount, } - out, err := dynamo.Query(historyCountInput) + out, err := dynamo.Query(context.TODO(), historyCountInput) if err != nil { return fmt.Errorf("error querying history item count: %w", err) } counts.HistoryItemCountPeriod1 = 0 counts.HistoryItemCountPeriod2 = 0 counts.HistoryItemCountPeriod3 = 0 - counts.HistoryItemCountPeriod4 = int(*out.Count) + counts.HistoryItemCountPeriod4 = int(out.Count) filterCond = expression.And( expression.AttributeExists(expression.Name(dataTypeAttrName)), expression.Name(dataTypeAttrName).NotEqual(expression.Value(HistoryTypeID)), @@ -98,13 +99,13 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou KeyConditionExpression: expr.KeyCondition(), FilterExpression: expr.Filter(), TableName: aws.String(Table), - Select: &selectCount, + Select: types.SelectCount, } - out, err = dynamo.Query(normalCountInput) + out, err = dynamo.Query(context.TODO(), normalCountInput) if err != nil { return fmt.Errorf("error querying history item count: %w", err) } - counts.ItemCount = int(*out.Count) + counts.ItemCount = int(out.Count) } counts.LastPeriodChangeTime = now counts.Version = CurrentCountVersion @@ -112,7 +113,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou timeSinceLastChange := now - counts.LastPeriodChangeTime if timeSinceLastChange >= periodDurationSecs { changeCount := int(timeSinceLastChange / periodDurationSecs) - for i := 0; i < changeCount; i++ { + for range changeCount { // The records from "period 1"/the earliest period // will be purged from the count, since they will be deleted via DDB TTL counts.HistoryItemCountPeriod1 = counts.HistoryItemCountPeriod2 @@ -130,7 +131,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou // a given client. func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, error) { primaryKey := PrimaryKey{ClientID: clientID, ID: clientID} - key, err := dynamodbattribute.MarshalMap(primaryKey) + key, err := attributevalue.MarshalMap(primaryKey) if err != nil { return nil, fmt.Errorf("error marshalling primary key to get item-count item: %w", err) } @@ -140,13 +141,13 @@ func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, er TableName: aws.String(Table), } - out, err := dynamo.GetItem(input) + out, err := dynamo.GetItem(context.TODO(), input) if err != nil { return nil, fmt.Errorf("error getting an item-count item: %w", err) } clientItemCounts := &ClientItemCounts{} - err = dynamodbattribute.UnmarshalMap(out.Item, clientItemCounts) + err = attributevalue.UnmarshalMap(out.Item, clientItemCounts) if err != nil { return nil, fmt.Errorf("error unmarshalling item-count item: %w", err) } @@ -169,7 +170,7 @@ func (dynamo *Dynamo) UpdateClientItemCount(counts *ClientItemCounts, newNormalI counts.HistoryItemCountPeriod4 += newHistoryItemCount counts.ItemCount += newNormalItemCount - item, err := dynamodbattribute.MarshalMap(*counts) + item, err := attributevalue.MarshalMap(*counts) if err != nil { return fmt.Errorf("error marshalling item counts: %w", err) } @@ -179,7 +180,7 @@ func (dynamo *Dynamo) UpdateClientItemCount(counts *ClientItemCounts, newNormalI TableName: aws.String(Table), } - _, err = dynamo.PutItem(input) + _, err = dynamo.PutItem(context.TODO(), input) if err != nil { return fmt.Errorf("error updating item-count item in dynamoDB: %w", err) } diff --git a/datastore/item_count_test.go b/datastore/item_count_test.go index ccb78ac2..54f307c2 100644 --- a/datastore/item_count_test.go +++ b/datastore/item_count_test.go @@ -4,13 +4,15 @@ import ( "sort" "testing" + "github.com/stretchr/testify/suite" + "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/datastore/datastoretest" - "github.com/stretchr/testify/suite" ) type ItemCountTestSuite struct { suite.Suite + dynamo *datastore.Dynamo } @@ -46,13 +48,13 @@ func (suite *ItemCountTestSuite) TestGetClientItemCount() { for _, item := range items { count, err := suite.dynamo.GetClientItemCount(item.ClientID) suite.Require().NoError(err, "GetClientItemCount should succeed") - suite.Assert().Equal(count.ItemCount, item.ItemCount, "ItemCount should match") + suite.Equal(count.ItemCount, item.ItemCount, "ItemCount should match") } // Non-exist client item count should succeed with count = 0. count, err := suite.dynamo.GetClientItemCount("client3") suite.Require().NoError(err, "Get non-exist ClientItemCount should succeed") - suite.Assert().Equal(count.ItemCount, 0) + suite.Equal(0, count.ItemCount) } func (suite *ItemCountTestSuite) TestUpdateClientItemCount() { @@ -81,7 +83,7 @@ func (suite *ItemCountTestSuite) TestUpdateClientItemCount() { clientCountItems[i].Version = 0 clientCountItems[i].LastPeriodChangeTime = 0 } - suite.Assert().Equal(expectedItems, clientCountItems) + suite.Equal(expectedItems, clientCountItems) } func TestItemCountTestSuite(t *testing.T) { diff --git a/datastore/sync_entity.go b/datastore/sync_entity.go index 690bf7ca..e87d8926 100644 --- a/datastore/sync_entity.go +++ b/datastore/sync_entity.go @@ -1,6 +1,8 @@ package datastore import ( + "context" + "errors" "fmt" "reflect" "sort" @@ -8,15 +10,16 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/expression" - "github.com/brave/go-sync/schema/protobuf/sync_pb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/google/uuid" "github.com/rs/zerolog/log" - uuid "github.com/satori/go.uuid" "google.golang.org/protobuf/proto" + + "github.com/brave/go-sync/schema/protobuf/sync_pb" ) const ( @@ -167,15 +170,15 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { // Write tag item for all data types, except for // the history type, which does not use tag items. if entity.ClientDefinedUniqueTag != nil && *entity.DataType != HistoryTypeID { - items := []*dynamodb.TransactWriteItem{} + items := make([]types.TransactWriteItem, 0, 2) // Additional item for ensuring tag's uniqueness for a specific client. item := NewServerClientUniqueTagItem(entity.ClientID, *entity.ClientDefinedUniqueTag, false) - av, err := dynamodbattribute.MarshalMap(*item) + av, err := attributevalue.MarshalMap(*item) if err != nil { return false, fmt.Errorf("error marshalling unique tag item to insert sync entity: %w", err) } - tagItem := &dynamodb.TransactWriteItem{ - Put: &dynamodb.Put{ + tagItem := types.TransactWriteItem{ + Put: &types.Put{ Item: av, ExpressionAttributeNames: expr.Names(), ExpressionAttributeValues: expr.Values(), @@ -185,12 +188,12 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { } // Normal sync item - av, err = dynamodbattribute.MarshalMap(*entity) + av, err = attributevalue.MarshalMap(*entity) if err != nil { return false, fmt.Errorf("error marshlling sync item to insert sync entity: %w", err) } - syncItem := &dynamodb.TransactWriteItem{ - Put: &dynamodb.Put{ + syncItem := types.TransactWriteItem{ + Put: &types.Put{ Item: av, ExpressionAttributeNames: expr.Names(), ExpressionAttributeValues: expr.Values(), @@ -201,11 +204,12 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { items = append(items, tagItem) items = append(items, syncItem) - _, err = dynamo.TransactWriteItems( + _, err = dynamo.TransactWriteItems(context.TODO(), &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { // Return conflict if insert condition failed. - if canceledException, ok := err.(*dynamodb.TransactionCanceledException); ok { + var canceledException *types.TransactionCanceledException + if errors.As(err, &canceledException) { for _, reason := range canceledException.CancellationReasons { if reason.Code != nil && *reason.Code == conditionalCheckFailed { return true, fmt.Errorf("error inserting sync item with client tag: %w", err) @@ -219,7 +223,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { } // Normal sync item - av, err := dynamodbattribute.MarshalMap(*entity) + av, err := attributevalue.MarshalMap(*entity) if err != nil { return false, fmt.Errorf("error marshalling sync item to insert sync entity: %w", err) } @@ -230,7 +234,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { ConditionExpression: expr.Condition(), TableName: aws.String(Table), } - _, err = dynamo.PutItem(input) + _, err = dynamo.PutItem(context.TODO(), input) if err != nil { return false, fmt.Errorf("error calling PutItem to insert sync item: %w", err) } @@ -241,7 +245,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { // tag item exists with the tag value for a specific client. func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bool, error) { tagItem := NewServerClientUniqueTagItemQuery(clientID, tag, true) - key, err := dynamodbattribute.MarshalMap(tagItem) + key, err := attributevalue.MarshalMap(tagItem) if err != nil { return false, fmt.Errorf("error marshalling key to check if server tag existed: %w", err) } @@ -252,7 +256,7 @@ func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bo TableName: aws.String(Table), } - out, err := dynamo.GetItem(input) + out, err := dynamo.GetItem(context.TODO(), input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if server tag existed: %w", err) } @@ -262,7 +266,7 @@ func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bo func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { primaryKey := PrimaryKey{ClientID: clientID, ID: ID} - key, err := dynamodbattribute.MarshalMap(primaryKey) + key, err := attributevalue.MarshalMap(primaryKey) if err != nil { return false, fmt.Errorf("error marshalling key to check if item existed: %w", err) @@ -274,7 +278,7 @@ func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { TableName: aws.String(Table), } - out, err := dynamo.GetItem(input) + out, err := dynamo.GetItem(context.TODO(), input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if item existed: %w", err) } @@ -287,7 +291,7 @@ func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { // we will write a tag item and a sync item. Items for all the entities in the // array would be written into DB in one transaction. func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) error { - items := []*dynamodb.TransactWriteItem{} + items := make([]types.TransactWriteItem, 0, len(entities)*2) for _, entity := range entities { // Create a condition for inserting new items only. cond := expression.AttributeNotExists(expression.Name(pk)) @@ -298,12 +302,12 @@ func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) e // Additional item for ensuring tag's uniqueness for a specific client. item := NewServerClientUniqueTagItem(entity.ClientID, *entity.ServerDefinedUniqueTag, true) - av, err := dynamodbattribute.MarshalMap(*item) + av, err := attributevalue.MarshalMap(*item) if err != nil { return fmt.Errorf("error marshalling tag item to insert sync entity with server tag: %w", err) } - tagItem := &dynamodb.TransactWriteItem{ - Put: &dynamodb.Put{ + tagItem := types.TransactWriteItem{ + Put: &types.Put{ Item: av, ExpressionAttributeNames: expr.Names(), ExpressionAttributeValues: expr.Values(), @@ -313,12 +317,12 @@ func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) e } // Normal sync item - av, err = dynamodbattribute.MarshalMap(*entity) + av, err = attributevalue.MarshalMap(*entity) if err != nil { return fmt.Errorf("error marshalling sync item to insert sync entity with server tag: %w", err) } - syncItem := &dynamodb.TransactWriteItem{ - Put: &dynamodb.Put{ + syncItem := types.TransactWriteItem{ + Put: &types.Put{ Item: av, ExpressionAttributeNames: expr.Names(), ExpressionAttributeValues: expr.Values(), @@ -331,7 +335,7 @@ func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) e items = append(items, syncItem) } - _, err := dynamo.TransactWriteItems( + _, err := dynamo.TransactWriteItems(context.TODO(), &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { return fmt.Errorf("error writing sync entities with server tags in a transaction: %w", err) @@ -350,7 +354,7 @@ func (dynamo *Dynamo) DisableSyncChain(clientID string) error { Ctime: now, } - av, err := dynamodbattribute.MarshalMap(disabledMarker) + av, err := attributevalue.MarshalMap(disabledMarker) if err != nil { return fmt.Errorf("error marshalling disabled marker: %w", err) } @@ -360,7 +364,7 @@ func (dynamo *Dynamo) DisableSyncChain(clientID string) error { TableName: aws.String(Table), } - _, err = dynamo.PutItem(markerInput) + _, err = dynamo.PutItem(context.TODO(), markerInput) if err != nil { return fmt.Errorf("error calling PutItem to insert sync item: %w", err) } @@ -388,25 +392,22 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { TableName: aws.String(Table), } - out, err := dynamo.Query(input) + out, err := dynamo.Query(context.TODO(), input) if err != nil { return syncEntities, fmt.Errorf("error doing query to get updates: %w", err) } - count := *out.Count + count := out.Count - err = dynamodbattribute.UnmarshalListOfMaps(out.Items, &syncEntities) + err = attributevalue.UnmarshalListOfMaps(out.Items, &syncEntities) if err != nil { return syncEntities, fmt.Errorf("error unmarshalling updated sync entities: %w", err) } - var i, j int64 + var i, j int32 for i = 0; i < count; i += maxTransactDeleteItemSize { - j = i + maxTransactDeleteItemSize - if j > count { - j = count - } + j = min(i+maxTransactDeleteItemSize, count) - items := []*dynamodb.TransactWriteItem{} + items := make([]types.TransactWriteItem, 0, j-i) for _, item := range syncEntities[i:j] { if item.ID == disabledChainID { continue @@ -420,46 +421,45 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { return syncEntities, fmt.Errorf("error deleting sync entities for client %s: %w", clientID, err) } - writeItem := dynamodb.TransactWriteItem{ - Delete: &dynamodb.Delete{ + writeItem := types.TransactWriteItem{ + Delete: &types.Delete{ ConditionExpression: expr.Condition(), ExpressionAttributeNames: expr.Names(), ExpressionAttributeValues: expr.Values(), TableName: aws.String(Table), - Key: map[string]*dynamodb.AttributeValue{ - pk: { - S: aws.String(item.ClientID), + Key: map[string]types.AttributeValue{ + pk: &types.AttributeValueMemberS{ + Value: item.ClientID, }, - sk: { - S: aws.String(item.ID), + sk: &types.AttributeValueMemberS{ + Value: item.ID, }, }, }, } - items = append(items, &writeItem) + items = append(items, writeItem) } else { // If row doesn't hold Mtime, delete as usual. - writeItem := dynamodb.TransactWriteItem{ - Delete: &dynamodb.Delete{ + writeItem := types.TransactWriteItem{ + Delete: &types.Delete{ TableName: aws.String(Table), - Key: map[string]*dynamodb.AttributeValue{ - pk: { - S: aws.String(item.ClientID), + Key: map[string]types.AttributeValue{ + pk: &types.AttributeValueMemberS{ + Value: item.ClientID, }, - sk: { - S: aws.String(item.ID), + sk: &types.AttributeValueMemberS{ + Value: item.ID, }, }, }, } - items = append(items, &writeItem) + items = append(items, writeItem) } - } - _, err = dynamo.TransactWriteItems(&dynamodb.TransactWriteItemsInput{TransactItems: items}) + _, err = dynamo.TransactWriteItems(context.TODO(), &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { return syncEntities, fmt.Errorf("error deleting sync entities for client %s: %w", clientID, err) } @@ -470,7 +470,7 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { // IsSyncChainDisabled checks whether a given sync chain has been deleted func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { - key, err := dynamodbattribute.MarshalMap(DisabledMarkerItemQuery{ + key, err := attributevalue.MarshalMap(DisabledMarkerItemQuery{ ClientID: clientID, ID: disabledChainID, }) @@ -483,7 +483,7 @@ func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { TableName: aws.String(Table), } - out, err := dynamo.GetItem(input) + out, err := dynamo.GetItem(context.TODO(), input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if sync chain disabled: %w", err) } @@ -494,7 +494,7 @@ func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { // UpdateSyncEntity updates a sync item in dynamoDB. func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bool, bool, error) { primaryKey := PrimaryKey{ClientID: entity.ClientID, ID: entity.ID} - key, err := dynamodbattribute.MarshalMap(primaryKey) + key, err := attributevalue.MarshalMap(primaryKey) if err != nil { return false, false, fmt.Errorf("error marshalling key to update sync entity: %w", err) } @@ -541,25 +541,25 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo if entity.Deleted != nil && entity.ClientDefinedUniqueTag != nil && *entity.Deleted && *entity.DataType != HistoryTypeID { pk := PrimaryKey{ ClientID: entity.ClientID, ID: clientTagItemPrefix + *entity.ClientDefinedUniqueTag} - tagItemKey, err := dynamodbattribute.MarshalMap(pk) + tagItemKey, err := attributevalue.MarshalMap(pk) if err != nil { return false, false, fmt.Errorf("error marshalling key to update sync entity: %w", err) } - items := []*dynamodb.TransactWriteItem{} - updateSyncItem := &dynamodb.TransactWriteItem{ - Update: &dynamodb.Update{ + items := make([]types.TransactWriteItem, 0, 2) + updateSyncItem := types.TransactWriteItem{ + Update: &types.Update{ Key: key, ExpressionAttributeNames: expr.Names(), ExpressionAttributeValues: expr.Values(), ConditionExpression: expr.Condition(), UpdateExpression: expr.Update(), - ReturnValuesOnConditionCheckFailure: aws.String(dynamodb.ReturnValueAllOld), + ReturnValuesOnConditionCheckFailure: types.ReturnValuesOnConditionCheckFailureAllOld, TableName: aws.String(Table), }, } - deleteTagItem := &dynamodb.TransactWriteItem{ - Delete: &dynamodb.Delete{ + deleteTagItem := types.TransactWriteItem{ + Delete: &types.Delete{ Key: tagItemKey, TableName: aws.String(Table), }, @@ -567,11 +567,12 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo items = append(items, updateSyncItem) items = append(items, deleteTagItem) - _, err = dynamo.TransactWriteItems( + _, err = dynamo.TransactWriteItems(context.TODO(), &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { // Return conflict if the update condition fails. - if canceledException, ok := err.(*dynamodb.TransactionCanceledException); ok { + var canceledException *types.TransactionCanceledException + if errors.As(err, &canceledException) { for _, reason := range canceledException.CancellationReasons { if reason.Code != nil && *reason.Code == conditionalCheckFailed { return true, false, nil @@ -594,24 +595,23 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo ExpressionAttributeValues: expr.Values(), ConditionExpression: expr.Condition(), UpdateExpression: expr.Update(), - ReturnValues: aws.String(dynamodb.ReturnValueAllOld), + ReturnValues: types.ReturnValueAllOld, TableName: aws.String(Table), } - out, err := dynamo.UpdateItem(input) + out, err := dynamo.UpdateItem(context.TODO(), input) if err != nil { - if aerr, ok := err.(awserr.Error); ok { + var conditionalCheckFailedException *types.ConditionalCheckFailedException + if errors.As(err, &conditionalCheckFailedException) { // Return conflict if the write condition fails. - if aerr.Code() == dynamodb.ErrCodeConditionalCheckFailedException { - return true, false, nil - } + return true, false, nil } return false, false, fmt.Errorf("error calling UpdateItem to update sync entity: %w", err) } // Unmarshal out.Attributes oldEntity := &SyncEntity{} - err = dynamodbattribute.UnmarshalMap(out.Attributes, oldEntity) + err = attributevalue.UnmarshalMap(out.Attributes, oldEntity) if err != nil { return false, false, fmt.Errorf("error unmarshalling old sync entity: %w", err) } @@ -664,52 +664,47 @@ func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFo FilterExpression: expr.Filter(), ProjectionExpression: aws.String(projPk), TableName: aws.String(Table), - Limit: aws.Int64(maxSize), + Limit: aws.Int32(int32(maxSize)), } - out, err := dynamo.Query(input) + out, err := dynamo.Query(context.TODO(), input) if err != nil { return false, syncEntities, fmt.Errorf("error doing query to get updates: %w", err) } - hasChangesRemaining := false - if out.LastEvaluatedKey != nil && len(out.LastEvaluatedKey) > 0 { - hasChangesRemaining = true - } + hasChangesRemaining := len(out.LastEvaluatedKey) > 0 - count := *(out.Count) + count := out.Count if count == 0 { // No updates return hasChangesRemaining, syncEntities, nil } // Use return (ClientID, ID) primary keys to get the actual items. - var outAv []map[string]*dynamodb.AttributeValue - var i, j int64 + var outAv []map[string]types.AttributeValue + var i, j int32 for i = 0; i < count; i += maxBatchGetItemSize { - j = i + maxBatchGetItemSize - if j > count { - j = count - } + j = min(i+maxBatchGetItemSize, count) batchInput := &dynamodb.BatchGetItemInput{ - RequestItems: map[string]*dynamodb.KeysAndAttributes{ + RequestItems: map[string]types.KeysAndAttributes{ Table: { Keys: out.Items[i:j], }, }, } - err := dynamo.BatchGetItemPages(batchInput, - func(batchOut *dynamodb.BatchGetItemOutput, last bool) bool { - outAv = append(outAv, batchOut.Responses[Table]...) - return last - }) - if err != nil { - return false, syncEntities, fmt.Errorf("error getting update items in a batch: %w", err) + // Use paginator to automatically handle UnprocessedKeys + paginator := dynamodb.NewBatchGetItemPaginator(dynamo.Client, batchInput) + for paginator.HasMorePages() { + batchOut, err := paginator.NextPage(context.TODO()) + if err != nil { + return false, syncEntities, fmt.Errorf("error getting update items in a batch: %w", err) + } + outAv = append(outAv, batchOut.Responses[Table]...) } } - err = dynamodbattribute.UnmarshalListOfMaps(outAv, &syncEntities) + err = attributevalue.UnmarshalListOfMaps(outAv, &syncEntities) if err != nil { return false, syncEntities, fmt.Errorf("error unmarshalling updated sync entities: %w", err) } @@ -732,19 +727,19 @@ func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFo func validatePBEntity(entity *sync_pb.SyncEntity) error { if entity == nil { - return fmt.Errorf("validate SyncEntity error: empty SyncEntity") + return errors.New("validate SyncEntity error: empty SyncEntity") } if entity.IdString == nil { - return fmt.Errorf("validate SyncEntity error: empty IdString") + return errors.New("validate SyncEntity error: empty IdString") } if entity.Version == nil { - return fmt.Errorf("validate SyncEntity error: empty Version") + return errors.New("validate SyncEntity error: empty Version") } if entity.Specifics == nil { - return fmt.Errorf("validate SyncEntity error: nil Specifics") + return errors.New("validate SyncEntity error: nil Specifics") } return nil @@ -785,7 +780,7 @@ func CreateDBSyncEntity(entity *sync_pb.SyncEntity, cacheGUID *string, clientID var originatorCacheGUID, originatorClientItemID *string if cacheGUID != nil { if *entity.Version == 0 { - id = uuid.NewV4().String() + id = uuid.New().String() } originatorCacheGUID = cacheGUID originatorClientItemID = entity.IdString diff --git a/datastore/sync_entity_test.go b/datastore/sync_entity_test.go index a20cd5ea..5c6d8919 100644 --- a/datastore/sync_entity_test.go +++ b/datastore/sync_entity_test.go @@ -7,16 +7,18 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/proto" + "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/datastore/datastoretest" "github.com/brave/go-sync/schema/protobuf/sync_pb" - "github.com/stretchr/testify/suite" - "google.golang.org/protobuf/proto" ) type SyncEntityTestSuite struct { suite.Suite + dynamo *datastore.Dynamo } @@ -51,18 +53,18 @@ func (suite *SyncEntityTestSuite) TestNewServerClientUniqueTagItem() { // We can't know the exact value for Mtime & Ctime. Make sure they're set, // set zero value for subsequent tests - suite.Assert().NotNil(actualClientTag.Mtime) - suite.Assert().NotNil(actualClientTag.Ctime) - suite.Assert().NotNil(actualServerTag.Mtime) - suite.Assert().NotNil(actualServerTag.Ctime) + suite.NotNil(actualClientTag.Mtime) + suite.NotNil(actualClientTag.Ctime) + suite.NotNil(actualServerTag.Mtime) + suite.NotNil(actualServerTag.Ctime) actualClientTag.Mtime = nil actualClientTag.Ctime = nil actualServerTag.Mtime = nil actualServerTag.Ctime = nil - suite.Assert().Equal(expectedServerTag, actualServerTag) - suite.Assert().Equal(expectedClientTag, actualClientTag) + suite.Equal(expectedServerTag, actualServerTag) + suite.Equal(expectedClientTag, actualClientTag) } func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { @@ -89,11 +91,11 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { // Each InsertSyncEntity without client tag should result in one sync item saved. tagItems, err := datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal( - 0, len(tagItems), "Insert without client tag should not insert tag items") + suite.Empty( + tagItems, "Insert without client tag should not insert tag items") syncItems, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1, entity2}) + suite.Equal([]datastore.SyncEntity{entity1, entity2}, syncItems) // Insert entity with client tag should result in one sync item and one tag // item saved. @@ -115,7 +117,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { entity4Copy.ID = "id4_copy" conflict, err := suite.dynamo.InsertSyncEntity(&entity4Copy) suite.Require().Error(err, "InsertSyncEntity with the same client tag and ClientID should fail") - suite.Assert().True(conflict, "Return conflict for duplicate client tag") + suite.True(conflict, "Return conflict for duplicate client tag") // Insert entity with the same client tag for other client should not fail. entity5 := entity3 @@ -130,16 +132,16 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { suite.Require().NoError(err, "ScanSyncEntities should succeed") expectedSyncItems := []datastore.SyncEntity{entity1, entity2, entity3, entity4, entity5} sort.Sort(datastore.SyncEntityByClientIDID(syncItems)) - suite.Assert().Equal(syncItems, expectedSyncItems) + suite.Equal(expectedSyncItems, syncItems) // Check tag items should be saved for entity3, entity4, entity5. tagItems, err = datastoretest.ScanTagItems(suite.dynamo) // Check that Ctime and Mtime have been set, reset to zero value for subsequent // tests - for i := 0; i < len(tagItems); i++ { - suite.Assert().NotNil(tagItems[i].Ctime) - suite.Assert().NotNil(tagItems[i].Mtime) + for i := range tagItems { + suite.NotNil(tagItems[i].Ctime) + suite.NotNil(tagItems[i].Mtime) tagItems[i].Ctime = nil tagItems[i].Mtime = nil @@ -152,7 +154,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { {ClientID: "client2", ID: "Client#tag1"}, } sort.Sort(datastore.TagItemByClientIDID(tagItems)) - suite.Assert().Equal(expectedTagItems, tagItems) + suite.Equal(expectedTagItems, tagItems) } func (suite *SyncEntityTestSuite) TestHasServerDefinedUniqueTag() { @@ -180,19 +182,19 @@ func (suite *SyncEntityTestSuite) TestHasServerDefinedUniqueTag() { hasTag, err := suite.dynamo.HasServerDefinedUniqueTag("client1", "tag1") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") - suite.Assert().Equal(hasTag, true) + suite.True(hasTag) hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client1", "tag2") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") - suite.Assert().Equal(hasTag, false) + suite.False(hasTag) hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client2", "tag1") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") - suite.Assert().Equal(hasTag, false) + suite.False(hasTag) hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client2", "tag2") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") - suite.Assert().Equal(hasTag, true) + suite.True(hasTag) } func (suite *SyncEntityTestSuite) TestHasItem() { @@ -220,19 +222,19 @@ func (suite *SyncEntityTestSuite) TestHasItem() { hasTag, err := suite.dynamo.HasItem("client1", "id1") suite.Require().NoError(err, "HasItem should succeed") - suite.Assert().Equal(hasTag, true) + suite.True(hasTag) hasTag, err = suite.dynamo.HasItem("client2", "id2") suite.Require().NoError(err, "HasItem should succeed") - suite.Assert().Equal(hasTag, true) + suite.True(hasTag) hasTag, err = suite.dynamo.HasItem("client2", "id3") suite.Require().NoError(err, "HasItem should succeed") - suite.Assert().Equal(hasTag, false) + suite.False(hasTag) hasTag, err = suite.dynamo.HasItem("client3", "id2") suite.Require().NoError(err, "HasItem should succeed") - suite.Assert().Equal(hasTag, false) + suite.False(hasTag) } func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { @@ -259,10 +261,10 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { // Check nothing is written to DB when it fails. syncItems, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(0, len(syncItems), "No items should be written if fail") + suite.Empty(syncItems, "No items should be written if fail") tagItems, err := datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal(0, len(tagItems), "No items should be written if fail") + suite.Empty(tagItems, "No items should be written if fail") entity2.ServerDefinedUniqueTag = aws.String("tag2") entity3 := entity1 @@ -278,15 +280,15 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { suite.Require().NoError(err, "ScanSyncEntities should succeed") expectedSyncItems := []datastore.SyncEntity{entity1, entity2, entity3} sort.Sort(datastore.SyncEntityByClientIDID(syncItems)) - suite.Assert().Equal(syncItems, expectedSyncItems) + suite.Equal(expectedSyncItems, syncItems) tagItems, err = datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") // Check that Ctime and Mtime have been set, reset to zero value for subsequent // tests - for i := 0; i < len(tagItems); i++ { - suite.Assert().NotNil(tagItems[i].Ctime) - suite.Assert().NotNil(tagItems[i].Mtime) + for i := range tagItems { + suite.NotNil(tagItems[i].Ctime) + suite.NotNil(tagItems[i].Mtime) tagItems[i].Ctime = nil tagItems[i].Mtime = nil @@ -298,7 +300,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { {ClientID: "client2", ID: "Server#tag1"}, } sort.Sort(datastore.TagItemByClientIDID(tagItems)) - suite.Assert().Equal(expectedTagItems, tagItems) + suite.Equal(expectedTagItems, tagItems) } func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { @@ -328,7 +330,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { // Check sync entities are inserted correctly in DB. syncItems, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1, entity2, entity3}) + suite.Equal([]datastore.SyncEntity{entity1, entity2, entity3}, syncItems) // Update without optional fields. updateEntity1 := entity1 @@ -340,8 +342,8 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity1.Specifics = []byte{3, 4} conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().True(deleted, "Delete operation should return true") + suite.False(conflict, "Successful update should not have conflict") + suite.True(deleted, "Delete operation should return true") // Update with optional fields. updateEntity2 := updateEntity1 @@ -354,8 +356,8 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity2.NonUniqueName = aws.String("non_unique_name") conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, *entity2.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().False(deleted, "Non-delete operation should return false") + suite.False(conflict, "Successful update should not have conflict") + suite.False(deleted, "Non-delete operation should return false") // Update with nil Folder and Deleted updateEntity3 := updateEntity1 @@ -364,8 +366,8 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity3.Deleted = nil conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity3, *entity3.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().False(deleted, "Non-delete operation should return false") + suite.False(conflict, "Successful update should not have conflict") + suite.False(deleted, "Non-delete operation should return false") // Reset these back to false because they will be the expected value in DB. updateEntity3.Folder = aws.Bool(false) updateEntity3.Deleted = aws.Bool(false) @@ -374,13 +376,13 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { // should return false. conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, 12345678) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().True(conflict, "Update with the same version should return conflict") - suite.Assert().False(deleted, "Conflict operation should return false for delete") + suite.True(conflict, "Update with the same version should return conflict") + suite.False(deleted, "Conflict operation should return false for delete") // Check sync entities are updated correctly in DB. syncItems, err = datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{updateEntity1, updateEntity2, updateEntity3}) + suite.Equal([]datastore.SyncEntity{updateEntity1, updateEntity2, updateEntity3}, syncItems) } func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { @@ -400,7 +402,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { } conflict, err := suite.dynamo.InsertSyncEntity(&entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - suite.Assert().False(conflict, "Successful insert should not have conflict") + suite.False(conflict, "Successful insert should not have conflict") updateEntity1 := entity1 updateEntity1.Version = aws.Int64(2) @@ -408,8 +410,8 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { updateEntity1.Mtime = aws.Int64(24242424) conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, 1) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().False(deleted, "Non-delete operation should return false") + suite.False(conflict, "Successful update should not have conflict") + suite.False(deleted, "Non-delete operation should return false") // should still succeed with the same version number, // since the version number should be ignored @@ -417,20 +419,20 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { updateEntity2.Mtime = aws.Int64(42424242) conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, 1) suite.Require().NoError(err, "UpdateSyncEntity should not return an error") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().False(deleted, "Non-delete operation should return false") + suite.False(conflict, "Successful update should not have conflict") + suite.False(deleted, "Non-delete operation should return false") updateEntity3 := entity1 updateEntity3.Deleted = aws.Bool(true) conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity3, 1) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().True(deleted, "Delete operation should return true") + suite.False(conflict, "Successful update should not have conflict") + suite.True(deleted, "Delete operation should return true") syncItems, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{updateEntity3}) + suite.Equal([]datastore.SyncEntity{updateEntity3}, syncItems) } func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { @@ -450,12 +452,12 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { } conflict, err := suite.dynamo.InsertSyncEntity(&entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - suite.Assert().False(conflict, "Successful insert should not have conflict") + suite.False(conflict, "Successful insert should not have conflict") // Check a tag item is inserted. tagItems, err := datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal(1, len(tagItems), "Tag item should be inserted") + suite.Len(tagItems, 1, "Tag item should be inserted") // Update it to version 23456789. updateEntity1 := entity1 @@ -466,39 +468,39 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { updateEntity1.Specifics = []byte{3, 4} conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().False(deleted, "Non-delete operation should return false") + suite.False(conflict, "Successful update should not have conflict") + suite.False(deleted, "Non-delete operation should return false") // Soft-delete the item with wrong version should get conflict. updateEntity1.Deleted = aws.Bool(true) conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().True(conflict, "Version mismatched update should have conflict") - suite.Assert().False(deleted, "Failed delete operation should return false") + suite.True(conflict, "Version mismatched update should have conflict") + suite.False(deleted, "Failed delete operation should return false") // Soft-delete the item with matched version. updateEntity1.Version = aws.Int64(34567890) conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity1, 23456789) suite.Require().NoError(err, "UpdateSyncEntity should succeed") - suite.Assert().False(conflict, "Successful update should not have conflict") - suite.Assert().True(deleted, "Delete operation should return true") + suite.False(conflict, "Successful update should not have conflict") + suite.True(deleted, "Delete operation should return true") // Check tag item is deleted. tagItems, err = datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal(0, len(tagItems), "Tag item should be deleted") + suite.Empty(tagItems, "Tag item should be deleted") // Insert another item with the same client tag again. entity2 := entity1 entity2.ID = "id2" conflict, err = suite.dynamo.InsertSyncEntity(&entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") - suite.Assert().False(conflict, "Successful insert should not have conflict") + suite.False(conflict, "Successful insert should not have conflict") // Check a tag item is inserted. tagItems, err = datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal(1, len(tagItems), "Tag item should be inserted") + suite.Len(tagItems, 1, "Tag item should be inserted") } func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { @@ -553,52 +555,52 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { // Get all updates for type 123 and client1 using token = 0. hasChangesRemaining, syncItems, err := suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1, entity2}) - suite.Assert().False(hasChangesRemaining) + suite.Equal([]datastore.SyncEntity{entity1, entity2}, syncItems) + suite.False(hasChangesRemaining) // Get all updates for type 124 and client1 using token = 0. hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(124, 0, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity3}) - suite.Assert().False(hasChangesRemaining) + suite.Equal([]datastore.SyncEntity{entity3}, syncItems) + suite.False(hasChangesRemaining) // Get all updates for type 123 and client2 using token = 0. hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client2", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity4}) - suite.Assert().False(hasChangesRemaining) + suite.Equal([]datastore.SyncEntity{entity4}, syncItems) + suite.False(hasChangesRemaining) // Get all updates for type 124 and client2 using token = 0. hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(124, 0, true, "client2", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(len(syncItems), 0) - suite.Assert().False(hasChangesRemaining) + suite.Empty(syncItems) + suite.False(hasChangesRemaining) // Test maxSize will limit the return entries size, and hasChangesRemaining // should be true when there are more updates available in the DB. hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 1) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1}) - suite.Assert().True(hasChangesRemaining) + suite.Equal([]datastore.SyncEntity{entity1}, syncItems) + suite.True(hasChangesRemaining) // Test when num of query items equal to the limit, hasChangesRemaining should // be true. hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 2) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1, entity2}) - suite.Assert().True(hasChangesRemaining) + suite.Equal([]datastore.SyncEntity{entity1, entity2}, syncItems) + suite.True(hasChangesRemaining) // Test fetchFolders will remove folder items if false hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, false, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity2}) - suite.Assert().False(hasChangesRemaining) + suite.Equal([]datastore.SyncEntity{entity2}, syncItems) + suite.False(hasChangesRemaining) // Get all updates for a type for a client using mtime of one item as token. hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 12345678, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity2}) - suite.Assert().False(hasChangesRemaining) + suite.Equal([]datastore.SyncEntity{entity2}, syncItems) + suite.False(hasChangesRemaining) // Test batch is working correctly for over 100 items err = datastoretest.ResetTable(suite.dynamo) @@ -631,15 +633,15 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 300) suite.Require().NoError(err, "GetUpdatesForType should succeed") sort.Sort(datastore.SyncEntityByMtime(expectedSyncItems)) - suite.Assert().Equal(syncItems, expectedSyncItems) - suite.Assert().False(hasChangesRemaining) + suite.Equal(expectedSyncItems, syncItems) + suite.False(hasChangesRemaining) // Test that when maxGUBatchSize is smaller than total updates, the first n // items ordered by Mtime should be returned. hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 200) suite.Require().NoError(err, "GetUpdatesForType should succeed") - suite.Assert().Equal(syncItems, expectedSyncItems[0:200]) - suite.Assert().True(hasChangesRemaining) + suite.Equal(syncItems, expectedSyncItems[0:200]) + suite.True(hasChangesRemaining) } func (suite *SyncEntityTestSuite) TestCreateDBSyncEntity() { @@ -691,67 +693,67 @@ func (suite *SyncEntityTestSuite) TestCreateDBSyncEntity() { suite.Require().NoError(err, "CreateDBSyncEntity should succeed") // Check ID is replaced with a server-generated ID. - suite.Assert().NotEqual( + suite.NotEqual( dbEntity.ID, *pbEntity.IdString, "ID should be a server-generated ID and not equal to the passed IdString") expectedDBEntity.ID = dbEntity.ID // Check Mtime and Ctime should be provided by the server if client does not // provide it. - suite.Assert().NotNil( + suite.NotNil( dbEntity.Ctime, "Mtime should not be nil if client did not pass one") - suite.Assert().NotNil( + suite.NotNil( dbEntity.Mtime, "Mtime should not be nil if client did not pass one") - suite.Assert().Equal( + suite.Equal( *dbEntity.Mtime, *dbEntity.Ctime, "Server should generate the same value for mtime and ctime when they're not provided by the client") expectedDBEntity.Ctime = dbEntity.Ctime expectedDBEntity.Mtime = dbEntity.Mtime expectedDBEntity.DataTypeMtime = aws.String("47745#" + strconv.FormatInt(*dbEntity.Mtime, 10)) - suite.Assert().Equal(dbEntity, &expectedDBEntity) - suite.Assert().Nil(dbEntity.ExpirationTime) + suite.Equal(&expectedDBEntity, dbEntity) + suite.Nil(dbEntity.ExpirationTime) pbEntity.Deleted = nil pbEntity.Folder = nil dbEntity, err = datastore.CreateDBSyncEntity(&pbEntity, guid, "client1") suite.Require().NoError(err, "CreateDBSyncEntity should succeed") - suite.Assert().False(*dbEntity.Deleted, "Default value should be set for Deleted for new entities") - suite.Assert().False(*dbEntity.Folder, "Default value should be set for Deleted for new entities") - suite.Assert().Nil(dbEntity.ExpirationTime) + suite.False(*dbEntity.Deleted, "Default value should be set for Deleted for new entities") + suite.False(*dbEntity.Folder, "Default value should be set for Deleted for new entities") + suite.Nil(dbEntity.ExpirationTime) // Check the case when Ctime and Mtime are provided by the client. pbEntity.Ctime = aws.Int64(12345678) pbEntity.Mtime = aws.Int64(12345678) dbEntity, err = datastore.CreateDBSyncEntity(&pbEntity, guid, "client1") suite.Require().NoError(err, "CreateDBSyncEntity should succeed") - suite.Assert().Equal(*dbEntity.Ctime, *pbEntity.Ctime, "Client's Ctime should be respected") - suite.Assert().NotEqual(*dbEntity.Mtime, *pbEntity.Mtime, "Client's Mtime should be replaced") - suite.Assert().Nil(dbEntity.ExpirationTime) + suite.Equal(*dbEntity.Ctime, *pbEntity.Ctime, "Client's Ctime should be respected") + suite.NotEqual(*dbEntity.Mtime, *pbEntity.Mtime, "Client's Mtime should be replaced") + suite.Nil(dbEntity.ExpirationTime) // When cacheGUID is nil, ID should be kept and no originator info are filled. dbEntity, err = datastore.CreateDBSyncEntity(&pbEntity, nil, "client1") suite.Require().NoError(err, "CreateDBSyncEntity should succeed") - suite.Assert().Equal(dbEntity.ID, *pbEntity.IdString) - suite.Assert().Nil(dbEntity.OriginatorCacheGUID) - suite.Assert().Nil(dbEntity.OriginatorClientItemID) - suite.Assert().Nil(dbEntity.ExpirationTime) + suite.Equal(dbEntity.ID, *pbEntity.IdString) + suite.Nil(dbEntity.OriginatorCacheGUID) + suite.Nil(dbEntity.OriginatorClientItemID) + suite.Nil(dbEntity.ExpirationTime) // Check that when updating from a previous version with guid, ID will not be // replaced. pbEntity.Version = aws.Int64(1) dbEntity, err = datastore.CreateDBSyncEntity(&pbEntity, guid, "client1") suite.Require().NoError(err, "CreateDBSyncEntity should succeed") - suite.Assert().Equal(dbEntity.ID, *pbEntity.IdString) - suite.Assert().Nil(dbEntity.Deleted, "Deleted won't apply its default value for updated entities") - suite.Assert().Nil(dbEntity.Folder, "Deleted won't apply its default value for updated entities") - suite.Assert().Nil(dbEntity.ExpirationTime) + suite.Equal(dbEntity.ID, *pbEntity.IdString) + suite.Nil(dbEntity.Deleted, "Deleted won't apply its default value for updated entities") + suite.Nil(dbEntity.Folder, "Deleted won't apply its default value for updated entities") + suite.Nil(dbEntity.ExpirationTime) // Empty unique position should be marshalled to nil without error. pbEntity.UniquePosition = nil dbEntity, err = datastore.CreateDBSyncEntity(&pbEntity, guid, "client1") suite.Require().NoError(err) - suite.Assert().Nil(dbEntity.UniquePosition) - suite.Assert().Nil(dbEntity.ExpirationTime) + suite.Nil(dbEntity.UniquePosition) + suite.Nil(dbEntity.ExpirationTime) // A history entity should have the client tag hash as the ID, // and an expiration time. @@ -759,15 +761,15 @@ func (suite *SyncEntityTestSuite) TestCreateDBSyncEntity() { pbEntity.Specifics = &sync_pb.EntitySpecifics{SpecificsVariant: historyEntitySpecific} dbEntity, err = datastore.CreateDBSyncEntity(&pbEntity, guid, "client1") suite.Require().NoError(err) - suite.Assert().Equal(dbEntity.ID, "client_tag") + suite.Equal("client_tag", dbEntity.ID) expectedExpirationTime := time.Now().Unix() + datastore.HistoryExpirationIntervalSecs - suite.Assert().Greater(*dbEntity.ExpirationTime+2, expectedExpirationTime) - suite.Assert().Less(*dbEntity.ExpirationTime-2, expectedExpirationTime) + suite.Greater(*dbEntity.ExpirationTime+2, expectedExpirationTime) + suite.Less(*dbEntity.ExpirationTime-2, expectedExpirationTime) // Empty specifics should report marshal error. pbEntity.Specifics = nil _, err = datastore.CreateDBSyncEntity(&pbEntity, guid, "client1") - suite.Assert().NotNil(err.Error(), "empty specifics should fail") + suite.NotNil(err.Error(), "empty specifics should fail") } func (suite *SyncEntityTestSuite) TestCreatePBSyncEntity() { @@ -829,19 +831,19 @@ func (suite *SyncEntityTestSuite) TestCreatePBSyncEntity() { suite.Require().NoError(err, "json.Marshal should succeed") s2, err := json.Marshal(&expectedPBEntity) suite.Require().NoError(err, "json.Marshal should succeed") - suite.Assert().Equal(s1, s2) + suite.Equal(s1, s2) // Nil UniquePosition should be unmarshalled as nil without error. dbEntity.UniquePosition = nil pbEntity, err = datastore.CreatePBSyncEntity(&dbEntity) suite.Require().NoError(err, "CreatePBSyncEntity should succeed") - suite.Assert().Nil(pbEntity.UniquePosition) + suite.Nil(pbEntity.UniquePosition) // Nil Specifics should be unmarshalled as nil without error. dbEntity.Specifics = nil pbEntity, err = datastore.CreatePBSyncEntity(&dbEntity) suite.Require().NoError(err, "CreatePBSyncEntity should succeed") - suite.Assert().Nil(pbEntity.Specifics) + suite.Nil(pbEntity.Specifics) } func (suite *SyncEntityTestSuite) TestDisableSyncChain() { @@ -851,9 +853,9 @@ func (suite *SyncEntityTestSuite) TestDisableSyncChain() { suite.Require().NoError(err, "DisableSyncChain should succeed") e, err := datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal(1, len(e)) - suite.Assert().Equal(clientID, e[0].ClientID) - suite.Assert().Equal(id, e[0].ID) + suite.Len(e, 1) + suite.Equal(clientID, e[0].ClientID) + suite.Equal(id, e[0].ID) } func (suite *SyncEntityTestSuite) TestIsSyncChainDisabled() { @@ -861,13 +863,13 @@ func (suite *SyncEntityTestSuite) TestIsSyncChainDisabled() { disabled, err := suite.dynamo.IsSyncChainDisabled(clientID) suite.Require().NoError(err, "IsSyncChainDisabled should succeed") - suite.Assert().Equal(false, disabled) + suite.False(disabled) err = suite.dynamo.DisableSyncChain(clientID) suite.Require().NoError(err, "DisableSyncChain should succeed") disabled, err = suite.dynamo.IsSyncChainDisabled(clientID) suite.Require().NoError(err, "IsSyncChainDisabled should succeed") - suite.Assert().Equal(true, disabled) + suite.True(disabled) } func (suite *SyncEntityTestSuite) TestClearServerData() { @@ -888,15 +890,15 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { e, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(1, len(e)) + suite.Len(e, 1) e, err = suite.dynamo.ClearServerData(entity.ClientID) suite.Require().NoError(err, "ClearServerData should succeed") - suite.Assert().Equal(1, len(e)) + suite.Len(e, 1) e, err = datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(0, len(e)) + suite.Empty(e) // Test clear tagged items entity1 := datastore.SyncEntity{ @@ -921,23 +923,23 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { e, err = datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(2, len(e), "No items should be written if fail") + suite.Len(e, 2, "No items should be written if fail") t, err := datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal(2, len(t), "No items should be written if fail") + suite.Len(t, 2, "No items should be written if fail") e, err = suite.dynamo.ClearServerData(entity.ClientID) suite.Require().NoError(err, "ClearServerData should succeed") - suite.Assert().Equal(4, len(e)) + suite.Len(e, 4) e, err = datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") - suite.Assert().Equal(0, len(e), "No items should be written if fail") + suite.Empty(e, "No items should be written if fail") t, err = datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") - suite.Assert().Equal(0, len(t), "No items should be written if fail") + suite.Empty(t, "No items should be written if fail") } func TestSyncEntityTestSuite(t *testing.T) { diff --git a/go.mod b/go.mod index 7c950573..df758edf 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,37 @@ module github.com/brave/go-sync go 1.24.0 require ( - github.com/aws/aws-sdk-go v1.55.8 + github.com/aws/aws-sdk-go-v2 v1.32.7 + github.com/aws/aws-sdk-go-v2/config v1.28.6 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.24 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.7.59 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.39.1 github.com/brave-intl/bat-go/libs v0.0.0-20250620104757-9e2f8ff87fd8 - github.com/getsentry/sentry-go v0.34.0 + github.com/getsentry/sentry-go v0.39.0 github.com/go-chi/chi/v5 v5.2.3 + github.com/google/uuid v1.6.0 github.com/prometheus/client_golang v1.23.2 - github.com/redis/go-redis/v9 v9.14.0 + github.com/redis/go-redis/v9 v9.17.0 github.com/rs/zerolog v1.34.0 - github.com/satori/go.uuid v1.2.0 github.com/stretchr/testify v1.11.1 google.golang.org/protobuf v1.36.8 ) require ( github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.47 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect + github.com/aws/smithy-go v1.22.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -35,6 +52,7 @@ require ( github.com/prometheus/procfs v0.16.1 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/satori/go.uuid v1.2.0 // indirect github.com/shengdoushi/base58 v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/throttled/throttled/v2 v2.13.0 // indirect diff --git a/go.sum b/go.sum index afb8d5d3..7d2dc97a 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,42 @@ github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/v github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= -github.com/aws/aws-sdk-go v1.55.8 h1:JRmEUbU52aJQZ2AjX4q4Wu7t4uZjOu71uyNmaWlUkJQ= -github.com/aws/aws-sdk-go v1.55.8/go.mod h1:ZkViS9AqA6otK+JBBNH2++sx1sgxrPKcSzPPvQkUtXk= +github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw= +github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= +github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.24 h1:oB+JFeqQrLSkMqVVWf3zQq5uUPpO84sQbwqoQ2AXYX0= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.24/go.mod h1:b2gkt7DFR5t8nhDoG7XfLM8RER+kKTxRxkeeXVhps30= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.7.59 h1:9tiJl90a05ktuXPrtFFQzUpbCdz5cX4StHPjlfhMH7M= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.7.59/go.mod h1:SO7V5LvuKWqc5ylPyrAla40HzxGrHHwiCZD2tO3kUbw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26/go.mod h1:FR8f4turZtNy6baO0KJ5FJUmXH/cSkI9fOngs0yl6mA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 h1:zXFLuEuMMUOvEARXFUVJdfqZ4bvvSgdGRq/ATcrQxzM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26/go.mod h1:3o2Wpy0bogG1kyOPrgkXA8pgIfEEv0+m19O9D5+W8y8= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.39.1 h1:SOJ3xkgrw8W0VQgyBUeep74yuf8kWALToFxNNwlHFvg= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.39.1/go.mod h1:J8xqRbx7HIc8ids2P8JbrKx9irONPEYq7Z1FpLDpi3I= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.11 h1:lBa70oU+Vmfjpl6cqjF1ZIJ0hiWkB7uQe5pGozE4yYg= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.11/go.mod h1:HywkMgYwY0uaybPvvctx6fkm3L1ssRKeGv7TPZ6OQ/M= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.7 h1:EqGlayejoCRXmnVC6lXl6phCm9R2+k35e0gWsO9G5DI= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.7/go.mod h1:BTw+t+/E5F3ZnDai/wSOYM54WUVjSdewE7Jvwtb7o+w= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/brave-intl/bat-go/libs v0.0.0-20250620104757-9e2f8ff87fd8 h1:DUBBDWBwozxpET6HCTDSRKAfyiqKRNvOyeqI463epaI= @@ -29,8 +63,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/getsentry/sentry-go v0.34.0 h1:1FCHBVp8TfSc8L10zqSwXUZNiOSF+10qw4czjarTiY4= -github.com/getsentry/sentry-go v0.34.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE= +github.com/getsentry/sentry-go v0.39.0 h1:uhnexj8PNCyCve37GSqxXOeXHh4cJNLNNB4w70Jtgo0= +github.com/getsentry/sentry-go v0.39.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s= github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= @@ -58,6 +92,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= @@ -108,8 +144,8 @@ github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+ github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= -github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE= -github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/redis/go-redis/v9 v9.17.0 h1:K6E+ZlYN95KSMmZeEQPbU/c++wfmEvfFB17yEAq/VhM= +github.com/redis/go-redis/v9 v9.17.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= diff --git a/middleware/auth.go b/middleware/auth.go index 99108b91..312d04ed 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,9 +4,10 @@ import ( "context" "net/http" - "github.com/brave/go-sync/auth" - syncContext "github.com/brave/go-sync/context" "github.com/rs/zerolog/log" + + "github.com/brave/go-sync/auth" + syncContext "github.com/brave/go-sync/synccontext" ) // Auth verifies the token provided is valid, and sets the client id in context diff --git a/middleware/disabled_chain.go b/middleware/disabled_chain.go index 282e5382..7aeb266c 100644 --- a/middleware/disabled_chain.go +++ b/middleware/disabled_chain.go @@ -3,11 +3,12 @@ package middleware import ( "net/http" - syncContext "github.com/brave/go-sync/context" - "github.com/brave/go-sync/datastore" - "github.com/brave/go-sync/schema/protobuf/sync_pb" "github.com/rs/zerolog/log" "google.golang.org/protobuf/proto" + + "github.com/brave/go-sync/datastore" + "github.com/brave/go-sync/schema/protobuf/sync_pb" + syncContext "github.com/brave/go-sync/synccontext" ) // DisabledChain is a middleware to check for disabled sync chains referenced in a request, diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 277f5236..087586f7 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -3,17 +3,18 @@ package middleware_test import ( "bytes" "context" - "fmt" + "errors" "net/http" "net/http/httptest" "testing" "time" + "github.com/stretchr/testify/suite" + "github.com/brave/go-sync/auth/authtest" - syncContext "github.com/brave/go-sync/context" "github.com/brave/go-sync/datastore/datastoretest" "github.com/brave/go-sync/middleware" - "github.com/stretchr/testify/suite" + syncContext "github.com/brave/go-sync/synccontext" ) type MiddlewareTestSuite struct { @@ -42,7 +43,7 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { ctx = context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) next = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - suite.Require().Equal(false, true) + suite.Fail("Should not reach this point") }) handler = middleware.DisabledChain(next) req, err = http.NewRequestWithContext(ctx, "POST", "v2/command/", bytes.NewBuffer([]byte{})) @@ -53,12 +54,11 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { // DB error datastore = new(datastoretest.MockDatastore) - datastore.On("IsSyncChainDisabled", clientID).Return(false, fmt.Errorf("unable to query db")) + datastore.On("IsSyncChainDisabled", clientID).Return(false, errors.New("unable to query db")) ctx = context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) next = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) handler = middleware.DisabledChain(next) - rr = httptest.NewRecorder() req, err = http.NewRequestWithContext(ctx, "POST", "v2/command/", bytes.NewBuffer([]byte{})) suite.Require().NoError(err, "NewRequestWithContext should succeed") rr = httptest.NewRecorder() @@ -71,7 +71,7 @@ func (suite *MiddlewareTestSuite) TestAuthMiddleware() { next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { ctx := r.Context() clientID := ctx.Value(syncContext.ContextKeyClientID) - suite.Require().NotNil(clientID, "Client ID should be set by auth middleware") + suite.NotNil(clientID, "Client ID should be set by auth middleware") }) handler := middleware.Auth(next) diff --git a/server/server.go b/server/server.go index 67d854e0..f9b38b31 100644 --- a/server/server.go +++ b/server/server.go @@ -16,17 +16,18 @@ import ( "github.com/brave-intl/bat-go/libs/handlers" "github.com/brave-intl/bat-go/libs/logging" batware "github.com/brave-intl/bat-go/libs/middleware" - "github.com/brave/go-sync/cache" - syncContext "github.com/brave/go-sync/context" - "github.com/brave/go-sync/controller" - "github.com/brave/go-sync/datastore" - "github.com/brave/go-sync/middleware" sentry "github.com/getsentry/sentry-go" "github.com/go-chi/chi/v5" chiware "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" "github.com/rs/zerolog/log" + + "github.com/brave/go-sync/cache" + "github.com/brave/go-sync/controller" + "github.com/brave/go-sync/datastore" + "github.com/brave/go-sync/middleware" + syncContext "github.com/brave/go-sync/synccontext" ) var ( @@ -141,6 +142,7 @@ func StartServer() { healthCheckActive = false // disable health check time.Sleep(60 * time.Second) + //nolint:errcheck // Error during shutdown in signal handler is acceptable srv.Shutdown(serverCtx) }() @@ -156,6 +158,7 @@ func StartServer() { } err := srv.ListenAndServe() + //nolint:errorlint // Error during shutdown in signal handler is acceptable if err == http.ErrServerClosed { log.Info().Msg("HTTP server closed") } else if err != nil { diff --git a/server/server_test.go b/server/server_test.go index 22f4ad56..e9af1204 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,14 +2,16 @@ package server_test import ( "context" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" - "github.com/brave/go-sync/server" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/brave/go-sync/server" ) var ( @@ -17,28 +19,29 @@ var ( serverCtx context.Context ) -func init() { +func TestMain(m *testing.M) { testCtx, logger := server.SetupLogger(context.Background()) serverCtx, mux = server.SetupRouter(testCtx, logger) + m.Run() } func TestPing(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) - assert.Nil(t, err) + require.NoError(t, err) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req.WithContext(serverCtx)) assert.Equal(t, http.StatusOK, rr.Code) expected := "." - actual, err := ioutil.ReadAll(rr.Result().Body) - assert.Nil(t, err) + actual, err := io.ReadAll(rr.Result().Body) + require.NoError(t, err) assert.Equal(t, expected, string(actual)) } func TestCommand(t *testing.T) { req, err := http.NewRequest("POST", "/v2/command/", nil) - assert.Nil(t, err) + require.NoError(t, err) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req.WithContext(serverCtx)) diff --git a/context/context.go b/synccontext/context.go similarity index 91% rename from context/context.go rename to synccontext/context.go index 2cad2ba4..e5b2019a 100644 --- a/context/context.go +++ b/synccontext/context.go @@ -1,4 +1,4 @@ -package context +package synccontext type Key string