From c6ba8210459c6b53a84af3be357048e125f5fb1f Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Fri, 24 Oct 2025 00:48:20 -0500 Subject: [PATCH 01/11] fix go.mod --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 951cfe3b..1755ba80 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.24 github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.7.59 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.39.1 - github.com/aws/smithy-go v1.22.1 github.com/brave-intl/bat-go/libs v0.0.0-20250620104757-9e2f8ff87fd8 github.com/getsentry/sentry-go v0.34.0 github.com/go-chi/chi/v5 v5.2.2 @@ -34,6 +33,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect + github.com/aws/smithy-go v1.22.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect From e1854b70e2e6fafa1e49cf8107fc00ea0d8f1062 Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Fri, 24 Oct 2025 03:11:08 -0500 Subject: [PATCH 02/11] fix: improve context propagation --- command/command.go | 38 +++--- command/command_test.go | 52 ++++---- command/server_defined_unique_entity.go | 9 +- command/server_defined_unique_entity_test.go | 7 +- controller/controller.go | 2 +- datastore/datastore.go | 24 ++-- datastore/datastoretest/mock_datastore.go | 46 +++---- datastore/instrumented_datastore.go | 45 +++---- datastore/item_count.go | 16 +-- datastore/item_count_test.go | 11 +- datastore/sync_entity.go | 44 +++---- datastore/sync_entity_test.go | 121 ++++++++++--------- middleware/disabled_chain.go | 2 +- 13 files changed, 213 insertions(+), 204 deletions(-) diff --git a/command/command.go b/command/command.go index 0bdd8c06..2533c070 100644 --- a/command/command.go +++ b/command/command.go @@ -33,7 +33,7 @@ const ( // handleGetUpdatesRequest handles GetUpdatesMessage and fills // GetUpdatesResponse. Target sync entities in the database will be updated or // deleted based on the client's requests. -func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessage, guRsp *sync_pb.GetUpdatesResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { +func handleGetUpdatesRequest(ctx context.Context, cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessage, guRsp *sync_pb.GetUpdatesResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { errCode := sync_pb.SyncEnums_SUCCESS // default value, might be changed later isNewClient := guMsg.GetUpdatesOrigin != nil && *guMsg.GetUpdatesOrigin == sync_pb.SyncEnums_NEW_CLIENT isPoll := guMsg.GetUpdatesOrigin != nil && *guMsg.GetUpdatesOrigin == sync_pb.SyncEnums_PERIODIC @@ -41,7 +41,7 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag // Reject the request if client has >= 50 devices in the chain. activeDevices := 0 for { - hasChangesRemaining, syncEntities, err := db.GetUpdatesForType(deviceInfoTypeID, 0, false, clientID, int64(maxGUBatchSize)) + hasChangesRemaining, syncEntities, err := db.GetUpdatesForType(ctx, deviceInfoTypeID, 0, false, clientID, int64(maxGUBatchSize)) if err != nil { log.Error().Err(err).Msgf("db.GetUpdatesForType failed for type %v", deviceInfoTypeID) errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -68,7 +68,7 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag } // Insert initial records if needed. - err := InsertServerDefinedUniqueEntities(db, clientID) + err := InsertServerDefinedUniqueEntities(ctx, db, clientID) if err != nil { log.Error().Err(err).Msg("Create server defined unique entities failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -135,7 +135,7 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag } curMaxSize := int64(maxSize) - int64(len(guRsp.Entries)) - hasChangesRemaining, entities, err := db.GetUpdatesForType(int(*fromProgressMarker.DataTypeId), token, fetchFolders, clientID, curMaxSize) + hasChangesRemaining, entities, err := db.GetUpdatesForType(ctx, int(*fromProgressMarker.DataTypeId), token, fetchFolders, clientID, curMaxSize) if err != nil { log.Error().Err(err).Msgf("db.GetUpdatesForType failed for type %v", *fromProgressMarker.DataTypeId) errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -194,8 +194,8 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag return &errCode, nil } -func getItemCounts(cache *cache.Cache, db datastore.Datastore, clientID string) (*datastore.ClientItemCounts, int, int, error) { - itemCounts, err := db.GetClientItemCount(clientID) +func getItemCounts(ctx context.Context, cache *cache.Cache, db datastore.Datastore, clientID string) (*datastore.ClientItemCounts, int, int, error) { + itemCounts, err := db.GetClientItemCount(ctx, clientID) if err != nil { return nil, 0, 0, err } @@ -222,7 +222,7 @@ func getInterimItemCounts(cache *cache.Cache, clientID string, clearCache bool) // For each commit entry: // - new sync entity is created and inserted into the database if version is 0. // - existed sync entity will be updated if version is greater than 0. -func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, commitRsp *sync_pb.CommitResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { +func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *sync_pb.CommitMessage, commitRsp *sync_pb.CommitResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { if commitMsg == nil { return nil, fmt.Errorf("nil commitMsg is received") } @@ -232,7 +232,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c return &errCode, nil } - itemCounts, newNormalCount, newHistoryCount, err := getItemCounts(cache, db, clientID) + itemCounts, newNormalCount, newHistoryCount, err := getItemCounts(ctx, cache, db, clientID) if err != nil { log.Error().Err(err).Msg("Get client's item count failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -283,7 +283,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c *entityToCommit.Version = *entityToCommit.Mtime if *entityToCommit.DataType == datastore.HistoryTypeID { // Check if item exists using client_unique_tag - isUpdateOp, err = db.HasItem(clientID, *entityToCommit.ClientDefinedUniqueTag) + isUpdateOp, err = db.HasItem(ctx, clientID, *entityToCommit.ClientDefinedUniqueTag) if err != nil { log.Error().Err(err).Msg("Insert history sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -305,7 +305,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // Insert all non-history items. For history items, ignore any items above history quoto // and lie to the client about the objects being synced instead of returning OVER_QUOTA // so the client can continue to sync other entities. - conflict, err := db.InsertSyncEntity(entityToCommit) + conflict, err := db.InsertSyncEntity(ctx, entityToCommit) if err != nil { log.Error().Err(err).Msg("Insert sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -330,7 +330,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c } } } else { // Update - conflict, deleted, err := db.UpdateSyncEntity(entityToCommit, oldVersion) + conflict, deleted, err := db.UpdateSyncEntity(ctx, entityToCommit, oldVersion) if err != nil { log.Error().Err(err).Msg("Update sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -378,7 +378,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c cache.SetTypeMtime(context.Background(), clientID, dataType, mtime) } - err = db.UpdateClientItemCount(itemCounts, newNormalCount, newHistoryCount) + err = db.UpdateClientItemCount(ctx, itemCounts, newNormalCount, newHistoryCount) if err != nil { // We only impose a soft quota limit on the item count for each client, so // we only log the error without further actions here. The reason of this @@ -395,18 +395,18 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // handleClearServerDataRequest handles clearing user data from the datastore and cache // and fills the response -func handleClearServerDataRequest(cache *cache.Cache, db datastore.Datastore, _ *sync_pb.ClearServerDataMessage, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { +func handleClearServerDataRequest(ctx context.Context, cache *cache.Cache, db datastore.Datastore, _ *sync_pb.ClearServerDataMessage, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { errCode := sync_pb.SyncEnums_SUCCESS var err error - err = db.DisableSyncChain(clientID) + err = db.DisableSyncChain(ctx, clientID) if err != nil { log.Error().Err(err).Msg("Failed to disable sync chain") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR return &errCode, err } - syncEntities, err := db.ClearServerData(clientID) + syncEntities, err := db.ClearServerData(ctx, clientID) if err != nil { errCode = sync_pb.SyncEnums_TRANSIENT_ERROR return &errCode, err @@ -433,7 +433,7 @@ func handleClearServerDataRequest(cache *cache.Cache, db datastore.Datastore, _ // HandleClientToServerMessage handles the protobuf ClientToServerMessage and // fills the protobuf ClientToServerResponse. -func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerMessage, pbRsp *sync_pb.ClientToServerResponse, db datastore.Datastore, clientID string) error { +func HandleClientToServerMessage(ctx context.Context, cache *cache.Cache, pb *sync_pb.ClientToServerMessage, pbRsp *sync_pb.ClientToServerResponse, db datastore.Datastore, clientID string) error { // Create ClientToServerResponse and fill general fields for both GU and // Commit. pbRsp.StoreBirthday = aws.String(storeBirthday) @@ -447,7 +447,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM } else if *pb.MessageContents == sync_pb.ClientToServerMessage_GET_UPDATES { guRsp := &sync_pb.GetUpdatesResponse{} pbRsp.GetUpdates = guRsp - pbRsp.ErrorCode, err = handleGetUpdatesRequest(cache, pb.GetUpdates, guRsp, db, clientID) + pbRsp.ErrorCode, err = handleGetUpdatesRequest(ctx, cache, pb.GetUpdates, guRsp, db, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) @@ -461,7 +461,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM } else if *pb.MessageContents == sync_pb.ClientToServerMessage_COMMIT { commitRsp := &sync_pb.CommitResponse{} pbRsp.Commit = commitRsp - pbRsp.ErrorCode, err = handleCommitRequest(cache, pb.Commit, commitRsp, db, clientID) + pbRsp.ErrorCode, err = handleCommitRequest(ctx, cache, pb.Commit, commitRsp, db, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) @@ -475,7 +475,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM } else if *pb.MessageContents == sync_pb.ClientToServerMessage_CLEAR_SERVER_DATA { csdRsp := &sync_pb.ClearServerDataResponse{} pbRsp.ClearServerData = csdRsp - pbRsp.ErrorCode, err = handleClearServerDataRequest(cache, db, pb.ClearServerData, clientID) + pbRsp.ErrorCode, err = handleClearServerDataRequest(ctx, cache, db, pb.ClearServerData, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) diff --git a/command/command_test.go b/command/command_test.go index b1c282a9..4ae37c3e 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -221,7 +221,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { // Commit and check response. suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) @@ -241,7 +241,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { marker, sync_pb.SyncEnums_GU_TRIGGER, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -265,7 +265,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) @@ -287,7 +287,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { marker, sync_pb.SyncEnums_GU_TRIGGER, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -312,7 +312,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { msg = getClientToServerCommitMsg(entries) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) @@ -328,7 +328,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { marker, sync_pb.SyncEnums_GU_TRIGGER, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -345,7 +345,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_NewClient() { rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -389,7 +389,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_GUBatchSize() { // Commit and check response. suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(4, len(rsp.Commit.Entryresponse)) @@ -413,7 +413,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) @@ -438,7 +438,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(4, len(rsp.Commit.Entryresponse)) @@ -459,7 +459,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) @@ -476,7 +476,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) @@ -496,7 +496,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(4, len(rsp.Commit.Entryresponse)) @@ -518,7 +518,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo msg := getClientToServerCommitMsg([]*sync_pb.SyncEntity{child0}) rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(1, len(rsp.Commit.Entryresponse)) @@ -547,7 +547,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(6, len(rsp.Commit.Entryresponse)) @@ -562,7 +562,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo marker, sync_pb.SyncEnums_GU_TRIGGER, true, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Equal(6, len(rsp.GetUpdates.Entries)) @@ -592,7 +592,7 @@ func insertSyncEntitiesWithoutUpdateCache( for _, entry := range entries { dbEntry, err := datastore.CreateDBSyncEntity(entry, nil, clientID) suite.Require().NoError(err, "Create db entity from pb entity should succeed") - _, err = suite.dynamo.InsertSyncEntity(dbEntry) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), dbEntry) suite.Require().NoError(err, "Insert sync entity should succeed") val, err := suite.cache.Get(context.Background(), clientID+"#"+strconv.Itoa(*dbEntry.DataType), false) @@ -616,7 +616,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(3, len(rsp.Commit.Entryresponse)) @@ -657,7 +657,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba marker, sync_pb.SyncEnums_PERIODIC, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Assert().Equal(0, len(rsp.GetUpdates.Entries)) @@ -680,7 +680,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(1, len(rsp.Commit.Entryresponse)) @@ -698,7 +698,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba marker, sync_pb.SyncEnums_PERIODIC, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Assert().Equal(2, len(rsp.GetUpdates.Entries)) @@ -719,7 +719,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Sk rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(1, len(rsp.Commit.Entryresponse)) @@ -746,7 +746,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Sk marker, sync_pb.SyncEnums_GU_TRIGGER, true, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Equal(1, len(rsp.GetUpdates.Entries)) @@ -765,7 +765,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Assert().Equal(2, len(rsp.Commit.Entryresponse)) @@ -788,7 +788,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch marker, sync_pb.SyncEnums_PERIODIC, true, &clientBatch) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Equal(2, len(rsp.GetUpdates.Entries)) @@ -805,7 +805,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch marker, sync_pb.SyncEnums_PERIODIC, true, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Equal(1, len(rsp.GetUpdates.Entries)) diff --git a/command/server_defined_unique_entity.go b/command/server_defined_unique_entity.go index a4a4acd1..8a0357c0 100644 --- a/command/server_defined_unique_entity.go +++ b/command/server_defined_unique_entity.go @@ -1,13 +1,14 @@ package command import ( + "context" "fmt" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/brave/go-sync/datastore" "github.com/brave/go-sync/schema/protobuf/sync_pb" - "github.com/satori/go.uuid" + uuid "github.com/satori/go.uuid" ) const ( @@ -41,11 +42,11 @@ func createServerDefinedUniqueEntity(name string, serverDefinedTag string, clien // InsertServerDefinedUniqueEntities inserts the server defined unique tag // entities if it is not in the DB yet for a specific client. -func InsertServerDefinedUniqueEntities(db datastore.Datastore, clientID string) error { +func InsertServerDefinedUniqueEntities(ctx context.Context, db datastore.Datastore, clientID string) error { var entities []*datastore.SyncEntity // Check if they're existed already for this client. // If yes, just return directly. - ready, err := db.HasServerDefinedUniqueTag(clientID, nigoriTag) + ready, err := db.HasServerDefinedUniqueTag(ctx, clientID, nigoriTag) if err != nil { return fmt.Errorf("error checking if entity with a server tag existed: %w", err) } @@ -89,7 +90,7 @@ func InsertServerDefinedUniqueEntities(db datastore.Datastore, clientID string) } // Start a transaction to insert all server defined unique entities - err = db.InsertSyncEntitiesWithServerTags(entities) + err = db.InsertSyncEntitiesWithServerTags(ctx, entities) if err != nil { return fmt.Errorf("error inserting entities with server tags: %w", err) } diff --git a/command/server_defined_unique_entity_test.go b/command/server_defined_unique_entity_test.go index 56e06faa..92510642 100644 --- a/command/server_defined_unique_entity_test.go +++ b/command/server_defined_unique_entity_test.go @@ -1,6 +1,7 @@ package command_test import ( + "context" "sort" "testing" @@ -45,10 +46,10 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TearDownTest() { func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEntities() { suite.Require().NoError( - command.InsertServerDefinedUniqueEntities(suite.dynamo, "client1"), + command.InsertServerDefinedUniqueEntities(context.Background(), suite.dynamo, "client1"), "InsertServerDefinedUniqueEntities should succeed") suite.Require().NoError( - command.InsertServerDefinedUniqueEntities(suite.dynamo, "client1"), + command.InsertServerDefinedUniqueEntities(context.Background(), suite.dynamo, "client1"), "InsertServerDefinedUniqueEntities again for a same client should succeed") expectedSyncAttrsMap := map[string]*SyncAttrs{ @@ -155,7 +156,7 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEn suite.Assert().Equal(0, len(expectedSyncAttrsMap)) suite.Require().NoError( - command.InsertServerDefinedUniqueEntities(suite.dynamo, "client2"), + command.InsertServerDefinedUniqueEntities(context.Background(), suite.dynamo, "client2"), "InsertServerDefinedUniqueEntities should succeed for another client") } diff --git a/controller/controller.go b/controller/controller.go index 7566072c..55a959d9 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -72,7 +72,7 @@ func Command(cache *cache.Cache, db datastore.Datastore) http.HandlerFunc { } pbRsp := &sync_pb.ClientToServerResponse{} - err = command.HandleClientToServerMessage(cache, pb, pbRsp, db, clientID) + err = command.HandleClientToServerMessage(ctx, cache, pb, pbRsp, db, clientID) if err != nil { log.Error().Err(err).Msg("Handle command message failed") http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/datastore/datastore.go b/datastore/datastore.go index 5e3e5451..217642f9 100644 --- a/datastore/datastore.go +++ b/datastore/datastore.go @@ -1,30 +1,32 @@ package datastore +import "context" + // Datastore abstracts over the underlying datastore. type Datastore interface { // Insert a new sync entity. - InsertSyncEntity(entity *SyncEntity) (bool, error) + InsertSyncEntity(ctx context.Context, entity *SyncEntity) (bool, error) // Insert a series of sync entities in a write transaction. - InsertSyncEntitiesWithServerTags(entities []*SyncEntity) error + InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*SyncEntity) error // Update an existing sync entity. - UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) + UpdateSyncEntity(ctx context.Context, entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) // Get updates for a specific type which are modified after the time of // client token for a given client. Besides the array of sync entities, a // boolean value indicating whether there are more updates to query in the // next batch is returned. - GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) + GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) // Check if a server-defined unique tag is in the datastore. - HasServerDefinedUniqueTag(clientID string, tag string) (bool, error) + HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (bool, error) // Get the count of sync items for a client. - GetClientItemCount(clientID string) (*ClientItemCounts, error) + GetClientItemCount(ctx context.Context, clientID string) (*ClientItemCounts, error) // Update the count of sync items for a client. - UpdateClientItemCount(counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error + UpdateClientItemCount(ctx context.Context, counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error // ClearServerData deletes all items for a given clientID - ClearServerData(clientID string) ([]SyncEntity, error) + ClearServerData(ctx context.Context, clientID string) ([]SyncEntity, error) // DisableSyncChain marks a chain as disabled so no further updates or commits can happen - DisableSyncChain(clientID string) error + DisableSyncChain(ctx context.Context, clientID string) error // IsSyncChainDisabled checks whether a given sync chain is deleted - IsSyncChainDisabled(clientID string) (bool, error) + IsSyncChainDisabled(ctx context.Context, clientID string) (bool, error) // Checks if sync item exists for a client - HasItem(clientID string, ID string) (bool, error) + HasItem(ctx context.Context, clientID string, ID string) (bool, error) } diff --git a/datastore/datastoretest/mock_datastore.go b/datastore/datastoretest/mock_datastore.go index 09665d40..4e3083e8 100644 --- a/datastore/datastoretest/mock_datastore.go +++ b/datastore/datastoretest/mock_datastore.go @@ -1,6 +1,8 @@ package datastoretest import ( + "context" + "github.com/brave/go-sync/datastore" "github.com/stretchr/testify/mock" ) @@ -11,66 +13,66 @@ type MockDatastore struct { } // InsertSyncEntity mocks calls to InsertSyncEntity -func (m *MockDatastore) InsertSyncEntity(entity *datastore.SyncEntity) (bool, error) { - args := m.Called(entity) +func (m *MockDatastore) InsertSyncEntity(ctx context.Context, entity *datastore.SyncEntity) (bool, error) { + args := m.Called(ctx, entity) return args.Bool(0), args.Error(1) } // InsertSyncEntitiesWithServerTags mocks calls to InsertSyncEntitiesWithServerTags -func (m *MockDatastore) InsertSyncEntitiesWithServerTags(entities []*datastore.SyncEntity) error { - args := m.Called(entities) +func (m *MockDatastore) InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*datastore.SyncEntity) error { + args := m.Called(ctx, entities) return args.Error(0) } // UpdateSyncEntity mocks calls to UpdateSyncEntity -func (m *MockDatastore) UpdateSyncEntity(entity *datastore.SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { - args := m.Called(entity, oldVersion) +func (m *MockDatastore) UpdateSyncEntity(ctx context.Context, entity *datastore.SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { + args := m.Called(ctx, entity, oldVersion) return args.Bool(0), args.Bool(1), args.Error(2) } // GetUpdatesForType mocks calls to GetUpdatesForType -func (m *MockDatastore) GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []datastore.SyncEntity, error) { - args := m.Called(dataType, clientToken, fetchFolders, clientID, maxSize) +func (m *MockDatastore) GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []datastore.SyncEntity, error) { + args := m.Called(ctx, dataType, clientToken, fetchFolders, clientID, maxSize) return args.Bool(0), args.Get(1).([]datastore.SyncEntity), args.Error(2) } // HasServerDefinedUniqueTag mocks calls to HasServerDefinedUniqueTag -func (m *MockDatastore) HasServerDefinedUniqueTag(clientID string, tag string) (bool, error) { - args := m.Called(clientID, tag) +func (m *MockDatastore) HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (bool, error) { + args := m.Called(ctx, clientID, tag) return args.Bool(0), args.Error(1) } -func (m *MockDatastore) HasItem(clientID string, ID string) (bool, error) { - args := m.Called(clientID, ID) +func (m *MockDatastore) HasItem(ctx context.Context, clientID string, ID string) (bool, error) { + args := m.Called(ctx, clientID, ID) return args.Bool(0), args.Error(1) } // GetClientItemCount mocks calls to GetClientItemCount -func (m *MockDatastore) GetClientItemCount(clientID string) (*datastore.ClientItemCounts, error) { - args := m.Called(clientID) +func (m *MockDatastore) GetClientItemCount(ctx context.Context, clientID string) (*datastore.ClientItemCounts, error) { + args := m.Called(ctx, clientID) return &datastore.ClientItemCounts{ClientID: clientID, ID: clientID}, args.Error(1) } // UpdateClientItemCount mocks calls to UpdateClientItemCount -func (m *MockDatastore) UpdateClientItemCount(counts *datastore.ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { - args := m.Called(counts, newNormalItemCount, newHistoryItemCount) +func (m *MockDatastore) UpdateClientItemCount(ctx context.Context, counts *datastore.ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { + args := m.Called(ctx, counts, newNormalItemCount, newHistoryItemCount) return args.Error(0) } // ClearServerData mocks calls to ClearServerData -func (m *MockDatastore) ClearServerData(clientID string) ([]datastore.SyncEntity, error) { - args := m.Called(clientID) +func (m *MockDatastore) ClearServerData(ctx context.Context, clientID string) ([]datastore.SyncEntity, error) { + args := m.Called(ctx, clientID) return args.Get(0).([]datastore.SyncEntity), args.Error(1) } // DisableSyncChain mocks calls to DisableSyncChain -func (m *MockDatastore) DisableSyncChain(clientID string) error { - args := m.Called(clientID) +func (m *MockDatastore) DisableSyncChain(ctx context.Context, clientID string) error { + args := m.Called(ctx, clientID) return args.Error(0) } // IsSyncChainDisabled mocks calls to IsSyncChainDisabled -func (m *MockDatastore) IsSyncChainDisabled(clientID string) (bool, error) { - args := m.Called(clientID) +func (m *MockDatastore) IsSyncChainDisabled(ctx context.Context, clientID string) (bool, error) { + args := m.Called(ctx, clientID) return args.Bool(0), args.Error(1) } diff --git a/datastore/instrumented_datastore.go b/datastore/instrumented_datastore.go index e2e82732..9765ed26 100644 --- a/datastore/instrumented_datastore.go +++ b/datastore/instrumented_datastore.go @@ -7,6 +7,7 @@ package datastore //go:generate gowrap gen -p github.com/brave/go-sync/datastore -i Datastore -t ../.prom-gowrap.tmpl -o instrumented_datastore.go import ( + "context" "time" "github.com/prometheus/client_golang/prometheus" @@ -38,7 +39,7 @@ func NewDatastoreWithPrometheus(base Datastore, instanceName string) DatastoreWi } // ClearServerData implements Datastore -func (_d DatastoreWithPrometheus) ClearServerData(clientID string) (sa1 []SyncEntity, err error) { +func (_d DatastoreWithPrometheus) ClearServerData(ctx context.Context, clientID string) (sa1 []SyncEntity, err error) { _since := time.Now() defer func() { result := "ok" @@ -48,11 +49,11 @@ func (_d DatastoreWithPrometheus) ClearServerData(clientID string) (sa1 []SyncEn datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "ClearServerData", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.ClearServerData(clientID) + return _d.base.ClearServerData(ctx, clientID) } // DisableSyncChain implements Datastore -func (_d DatastoreWithPrometheus) DisableSyncChain(clientID string) (err error) { +func (_d DatastoreWithPrometheus) DisableSyncChain(ctx context.Context, clientID string) (err error) { _since := time.Now() defer func() { result := "ok" @@ -62,11 +63,11 @@ func (_d DatastoreWithPrometheus) DisableSyncChain(clientID string) (err error) datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "DisableSyncChain", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.DisableSyncChain(clientID) + return _d.base.DisableSyncChain(ctx, clientID) } // GetClientItemCount implements Datastore -func (_d DatastoreWithPrometheus) GetClientItemCount(clientID string) (counts *ClientItemCounts, err error) { +func (_d DatastoreWithPrometheus) GetClientItemCount(ctx context.Context, clientID string) (counts *ClientItemCounts, err error) { _since := time.Now() defer func() { result := "ok" @@ -76,11 +77,11 @@ func (_d DatastoreWithPrometheus) GetClientItemCount(clientID string) (counts *C datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "GetClientItemCount", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.GetClientItemCount(clientID) + return _d.base.GetClientItemCount(ctx, clientID) } // GetUpdatesForType implements Datastore -func (_d DatastoreWithPrometheus) GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (b1 bool, sa1 []SyncEntity, err error) { +func (_d DatastoreWithPrometheus) GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (b1 bool, sa1 []SyncEntity, err error) { _since := time.Now() defer func() { result := "ok" @@ -90,11 +91,11 @@ func (_d DatastoreWithPrometheus) GetUpdatesForType(dataType int, clientToken in datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "GetUpdatesForType", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.GetUpdatesForType(dataType, clientToken, fetchFolders, clientID, maxSize) + return _d.base.GetUpdatesForType(ctx, dataType, clientToken, fetchFolders, clientID, maxSize) } // HasItem implements Datastore -func (_d DatastoreWithPrometheus) HasItem(clientID string, ID string) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) HasItem(ctx context.Context, clientID string, ID string) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -104,11 +105,11 @@ func (_d DatastoreWithPrometheus) HasItem(clientID string, ID string) (b1 bool, datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "HasItem", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.HasItem(clientID, ID) + return _d.base.HasItem(ctx, clientID, ID) } // HasServerDefinedUniqueTag implements Datastore -func (_d DatastoreWithPrometheus) HasServerDefinedUniqueTag(clientID string, tag string) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -118,11 +119,11 @@ func (_d DatastoreWithPrometheus) HasServerDefinedUniqueTag(clientID string, tag datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "HasServerDefinedUniqueTag", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.HasServerDefinedUniqueTag(clientID, tag) + return _d.base.HasServerDefinedUniqueTag(ctx, clientID, tag) } // InsertSyncEntitiesWithServerTags implements Datastore -func (_d DatastoreWithPrometheus) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) (err error) { +func (_d DatastoreWithPrometheus) InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*SyncEntity) (err error) { _since := time.Now() defer func() { result := "ok" @@ -132,11 +133,11 @@ func (_d DatastoreWithPrometheus) InsertSyncEntitiesWithServerTags(entities []*S datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "InsertSyncEntitiesWithServerTags", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.InsertSyncEntitiesWithServerTags(entities) + return _d.base.InsertSyncEntitiesWithServerTags(ctx, entities) } // InsertSyncEntity implements Datastore -func (_d DatastoreWithPrometheus) InsertSyncEntity(entity *SyncEntity) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) InsertSyncEntity(ctx context.Context, entity *SyncEntity) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -146,11 +147,11 @@ func (_d DatastoreWithPrometheus) InsertSyncEntity(entity *SyncEntity) (b1 bool, datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "InsertSyncEntity", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.InsertSyncEntity(entity) + return _d.base.InsertSyncEntity(ctx, entity) } // IsSyncChainDisabled implements Datastore -func (_d DatastoreWithPrometheus) IsSyncChainDisabled(clientID string) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) IsSyncChainDisabled(ctx context.Context, clientID string) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -160,11 +161,11 @@ func (_d DatastoreWithPrometheus) IsSyncChainDisabled(clientID string) (b1 bool, datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "IsSyncChainDisabled", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.IsSyncChainDisabled(clientID) + return _d.base.IsSyncChainDisabled(ctx, clientID) } // UpdateClientItemCount implements Datastore -func (_d DatastoreWithPrometheus) UpdateClientItemCount(counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) (err error) { +func (_d DatastoreWithPrometheus) UpdateClientItemCount(ctx context.Context, counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) (err error) { _since := time.Now() defer func() { result := "ok" @@ -174,11 +175,11 @@ func (_d DatastoreWithPrometheus) UpdateClientItemCount(counts *ClientItemCounts datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "UpdateClientItemCount", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.UpdateClientItemCount(counts, newNormalItemCount, newHistoryItemCount) + return _d.base.UpdateClientItemCount(ctx, counts, newNormalItemCount, newHistoryItemCount) } // UpdateSyncEntity implements Datastore -func (_d DatastoreWithPrometheus) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { +func (_d DatastoreWithPrometheus) UpdateSyncEntity(ctx context.Context, entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -188,5 +189,5 @@ func (_d DatastoreWithPrometheus) UpdateSyncEntity(entity *SyncEntity, oldVersio datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "UpdateSyncEntity", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.UpdateSyncEntity(entity, oldVersion) + return _d.base.UpdateSyncEntity(ctx, entity, oldVersion) } diff --git a/datastore/item_count.go b/datastore/item_count.go index 13a965e2..caa7aa39 100644 --- a/datastore/item_count.go +++ b/datastore/item_count.go @@ -51,7 +51,7 @@ func (counts *ClientItemCounts) SumHistoryCounts() int { counts.HistoryItemCountPeriod4 } -func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCounts) error { +func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(ctx context.Context, counts *ClientItemCounts) error { now := time.Now().Unix() if counts.Version < CurrentCountVersion { if counts.ItemCount > 0 { @@ -75,7 +75,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou TableName: aws.String(Table), Select: types.SelectCount, } - out, err := dynamo.Query(context.TODO(), historyCountInput) + out, err := dynamo.Query(ctx, historyCountInput) if err != nil { return fmt.Errorf("error querying history item count: %w", err) } @@ -101,7 +101,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou TableName: aws.String(Table), Select: types.SelectCount, } - out, err = dynamo.Query(context.TODO(), normalCountInput) + out, err = dynamo.Query(ctx, normalCountInput) if err != nil { return fmt.Errorf("error querying history item count: %w", err) } @@ -129,7 +129,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou // GetClientItemCount returns the count of non-deleted sync items stored for // a given client. -func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, error) { +func (dynamo *Dynamo) GetClientItemCount(ctx context.Context, clientID string) (*ClientItemCounts, error) { primaryKey := PrimaryKey{ClientID: clientID, ID: clientID} key, err := attributevalue.MarshalMap(primaryKey) if err != nil { @@ -141,7 +141,7 @@ func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, er TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return nil, fmt.Errorf("error getting an item-count item: %w", err) } @@ -157,7 +157,7 @@ func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, er clientItemCounts.ID = clientID } - if err = dynamo.initRealCountsAndUpdateHistoryCounts(clientItemCounts); err != nil { + if err = dynamo.initRealCountsAndUpdateHistoryCounts(ctx, clientItemCounts); err != nil { return nil, err } @@ -166,7 +166,7 @@ func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, er // UpdateClientItemCount updates the count of non-deleted sync items for a // given client stored in the dynamoDB. -func (dynamo *Dynamo) UpdateClientItemCount(counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { +func (dynamo *Dynamo) UpdateClientItemCount(ctx context.Context, counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { counts.HistoryItemCountPeriod4 += newHistoryItemCount counts.ItemCount += newNormalItemCount @@ -180,7 +180,7 @@ func (dynamo *Dynamo) UpdateClientItemCount(counts *ClientItemCounts, newNormalI TableName: aws.String(Table), } - _, err = dynamo.PutItem(context.TODO(), input) + _, err = dynamo.PutItem(ctx, input) if err != nil { return fmt.Errorf("error updating item-count item in dynamoDB: %w", err) } diff --git a/datastore/item_count_test.go b/datastore/item_count_test.go index ccb78ac2..a404747e 100644 --- a/datastore/item_count_test.go +++ b/datastore/item_count_test.go @@ -1,6 +1,7 @@ package datastore_test import ( + "context" "sort" "testing" @@ -40,17 +41,17 @@ func (suite *ItemCountTestSuite) TestGetClientItemCount() { for _, item := range items { existing := datastore.ClientItemCounts{ClientID: item.ClientID, ID: item.ID, Version: datastore.CurrentCountVersion} suite.Require().NoError( - suite.dynamo.UpdateClientItemCount(&existing, item.ItemCount, 0)) + suite.dynamo.UpdateClientItemCount(context.Background(), &existing, item.ItemCount, 0)) } for _, item := range items { - count, err := suite.dynamo.GetClientItemCount(item.ClientID) + count, err := suite.dynamo.GetClientItemCount(context.Background(), item.ClientID) suite.Require().NoError(err, "GetClientItemCount should succeed") suite.Assert().Equal(count.ItemCount, item.ItemCount, "ItemCount should match") } // Non-exist client item count should succeed with count = 0. - count, err := suite.dynamo.GetClientItemCount("client3") + count, err := suite.dynamo.GetClientItemCount(context.Background(), "client3") suite.Require().NoError(err, "Get non-exist ClientItemCount should succeed") suite.Assert().Equal(count.ItemCount, 0) } @@ -67,10 +68,10 @@ func (suite *ItemCountTestSuite) TestUpdateClientItemCount() { } for _, item := range items { - count, err := suite.dynamo.GetClientItemCount(item.ClientID) + count, err := suite.dynamo.GetClientItemCount(context.Background(), item.ClientID) suite.Require().NoError(err) suite.Require().NoError( - suite.dynamo.UpdateClientItemCount(count, item.ItemCount, 0)) + suite.dynamo.UpdateClientItemCount(context.Background(), count, item.ItemCount, 0)) } clientCountItems, err := datastoretest.ScanClientItemCounts(suite.dynamo) diff --git a/datastore/sync_entity.go b/datastore/sync_entity.go index 3e831216..ccea954e 100644 --- a/datastore/sync_entity.go +++ b/datastore/sync_entity.go @@ -158,7 +158,7 @@ func NewServerClientUniqueTagItemQuery(clientID string, tag string, isServer boo // write a sync item along with a tag item to ensure the uniqueness of the // client tag. Otherwise, only a sync item is written into DB without using // transactions. -func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { +func (dynamo *Dynamo) InsertSyncEntity(ctx context.Context, entity *SyncEntity) (bool, error) { // Create a condition for inserting new items only. cond := expression.AttributeNotExists(expression.Name(pk)) expr, err := expression.NewBuilder().WithCondition(cond).Build() @@ -203,7 +203,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { items = append(items, tagItem) items = append(items, syncItem) - _, err = dynamo.TransactWriteItems(context.TODO(), + _, err = dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { // Return conflict if insert condition failed. @@ -233,7 +233,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { ConditionExpression: expr.Condition(), TableName: aws.String(Table), } - _, err = dynamo.PutItem(context.TODO(), input) + _, err = dynamo.PutItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling PutItem to insert sync item: %w", err) } @@ -242,7 +242,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { // HasServerDefinedUniqueTag check the tag item to see if there is already a // tag item exists with the tag value for a specific client. -func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bool, error) { +func (dynamo *Dynamo) HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (bool, error) { tagItem := NewServerClientUniqueTagItemQuery(clientID, tag, true) key, err := attributevalue.MarshalMap(tagItem) if err != nil { @@ -255,7 +255,7 @@ func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bo TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if server tag existed: %w", err) } @@ -263,7 +263,7 @@ func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bo return out.Item != nil, nil } -func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { +func (dynamo *Dynamo) HasItem(ctx context.Context, clientID string, ID string) (bool, error) { primaryKey := PrimaryKey{ClientID: clientID, ID: ID} key, err := attributevalue.MarshalMap(primaryKey) @@ -277,7 +277,7 @@ func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if item existed: %w", err) } @@ -289,7 +289,7 @@ func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { // server-defined unique tags. To ensure the uniqueness, for each sync entity, // we will write a tag item and a sync item. Items for all the entities in the // array would be written into DB in one transaction. -func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) error { +func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*SyncEntity) error { items := make([]types.TransactWriteItem, 0, len(entities)*2) for _, entity := range entities { // Create a condition for inserting new items only. @@ -334,7 +334,7 @@ func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) e items = append(items, syncItem) } - _, err := dynamo.TransactWriteItems(context.TODO(), + _, err := dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { return fmt.Errorf("error writing sync entities with server tags in a transaction: %w", err) @@ -343,7 +343,7 @@ func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) e } // DisableSyncChain marks a chain as disabled so no further updates or commits can happen -func (dynamo *Dynamo) DisableSyncChain(clientID string) error { +func (dynamo *Dynamo) DisableSyncChain(ctx context.Context, clientID string) error { now := aws.Int64(time.Now().UnixMilli()) disabledMarker := DisabledMarkerItem{ ClientID: clientID, @@ -363,7 +363,7 @@ func (dynamo *Dynamo) DisableSyncChain(clientID string) error { TableName: aws.String(Table), } - _, err = dynamo.PutItem(context.TODO(), markerInput) + _, err = dynamo.PutItem(ctx, markerInput) if err != nil { return fmt.Errorf("error calling PutItem to insert sync item: %w", err) } @@ -372,7 +372,7 @@ func (dynamo *Dynamo) DisableSyncChain(clientID string) error { } // ClearServerData deletes all items for a given clientID -func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { +func (dynamo *Dynamo) ClearServerData(ctx context.Context, clientID string) ([]SyncEntity, error) { syncEntities := []SyncEntity{} pkb := expression.Key(pk) pkv := expression.Value(clientID) @@ -391,7 +391,7 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { TableName: aws.String(Table), } - out, err := dynamo.Query(context.TODO(), input) + out, err := dynamo.Query(ctx, input) if err != nil { return syncEntities, fmt.Errorf("error doing query to get updates: %w", err) } @@ -462,7 +462,7 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { } - _, err = dynamo.TransactWriteItems(context.TODO(), &dynamodb.TransactWriteItemsInput{TransactItems: items}) + _, err = dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { return syncEntities, fmt.Errorf("error deleting sync entities for client %s: %w", clientID, err) } @@ -472,7 +472,7 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { } // IsSyncChainDisabled checks whether a given sync chain has been deleted -func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { +func (dynamo *Dynamo) IsSyncChainDisabled(ctx context.Context, clientID string) (bool, error) { key, err := attributevalue.MarshalMap(DisabledMarkerItemQuery{ ClientID: clientID, ID: disabledChainID, @@ -486,7 +486,7 @@ func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if sync chain disabled: %w", err) } @@ -495,7 +495,7 @@ func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { } // UpdateSyncEntity updates a sync item in dynamoDB. -func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bool, bool, error) { +func (dynamo *Dynamo) UpdateSyncEntity(ctx context.Context, entity *SyncEntity, oldVersion int64) (bool, bool, error) { primaryKey := PrimaryKey{ClientID: entity.ClientID, ID: entity.ID} key, err := attributevalue.MarshalMap(primaryKey) if err != nil { @@ -570,7 +570,7 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo items = append(items, updateSyncItem) items = append(items, deleteTagItem) - _, err = dynamo.TransactWriteItems(context.TODO(), + _, err = dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { // Return conflict if the update condition fails. @@ -602,7 +602,7 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo TableName: aws.String(Table), } - out, err := dynamo.UpdateItem(context.TODO(), input) + out, err := dynamo.UpdateItem(ctx, input) if err != nil { var conditionalCheckFailedException *types.ConditionalCheckFailedException if errors.As(err, &conditionalCheckFailedException) { @@ -634,7 +634,7 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo // To do this in dynamoDB, we use (ClientID, DataType#Mtime) as GSI to get a // list of (ClientID, ID) primary keys with the given condition, then read the // actual sync item using the list of primary keys. -func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) { +func (dynamo *Dynamo) GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) { syncEntities := []SyncEntity{} // Get (ClientID, ID) pairs which are updates after mtime for a data type, @@ -670,7 +670,7 @@ func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFo Limit: aws.Int32(int32(maxSize)), } - out, err := dynamo.Query(context.TODO(), input) + out, err := dynamo.Query(ctx, input) if err != nil { return false, syncEntities, fmt.Errorf("error doing query to get updates: %w", err) } @@ -705,7 +705,7 @@ func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFo // Use paginator to automatically handle UnprocessedKeys paginator := dynamodb.NewBatchGetItemPaginator(dynamo.Client, batchInput) for paginator.HasMorePages() { - batchOut, err := paginator.NextPage(context.TODO()) + batchOut, err := paginator.NextPage(ctx) if err != nil { return false, syncEntities, fmt.Errorf("error getting update items in a batch: %w", err) } diff --git a/datastore/sync_entity_test.go b/datastore/sync_entity_test.go index 5cc24b4c..612cbcd0 100644 --- a/datastore/sync_entity_test.go +++ b/datastore/sync_entity_test.go @@ -1,6 +1,7 @@ package datastore_test import ( + "context" "encoding/json" "sort" "strconv" @@ -79,11 +80,11 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { } entity2 := entity1 entity2.ID = "id2" - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity with other ID should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity1) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().Error(err, "InsertSyncEntity with the same ClientID and ID should fail") // Each InsertSyncEntity without client tag should result in one sync item saved. @@ -100,20 +101,20 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { entity3 := entity1 entity3.ID = "id3" entity3.ClientDefinedUniqueTag = aws.String("tag1") - _, err = suite.dynamo.InsertSyncEntity(&entity3) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity3) suite.Require().NoError(err, "InsertSyncEntity should succeed") // Insert entity with different tag for same ClientID should succeed. entity4 := entity3 entity4.ID = "id4" entity4.ClientDefinedUniqueTag = aws.String("tag2") - _, err = suite.dynamo.InsertSyncEntity(&entity4) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity4) suite.Require().NoError(err, "InsertSyncEntity with different server tag should succeed") // Insert entity with the same client tag and ClientID should fail with conflict. entity4Copy := entity4 entity4Copy.ID = "id4_copy" - conflict, err := suite.dynamo.InsertSyncEntity(&entity4Copy) + conflict, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity4Copy) suite.Require().Error(err, "InsertSyncEntity with the same client tag and ClientID should fail") suite.Assert().True(conflict, "Return conflict for duplicate client tag") @@ -121,7 +122,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { entity5 := entity3 entity5.ClientID = "client2" entity5.ID = "id5" - _, err = suite.dynamo.InsertSyncEntity(&entity5) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity5) suite.Require().NoError(err, "InsertSyncEntity with the same client tag for another client should succeed") @@ -175,22 +176,22 @@ func (suite *SyncEntityTestSuite) TestHasServerDefinedUniqueTag() { tag2.ServerDefinedUniqueTag = aws.String("tag2") entities := []*datastore.SyncEntity{&tag1, &tag2} - err := suite.dynamo.InsertSyncEntitiesWithServerTags(entities) + err := suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities) suite.Require().NoError(err, "Insert sync entities should succeed") - hasTag, err := suite.dynamo.HasServerDefinedUniqueTag("client1", "tag1") + hasTag, err := suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client1", "tag1") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.Assert().Equal(hasTag, true) - hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client1", "tag2") + hasTag, err = suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client1", "tag2") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.Assert().Equal(hasTag, false) - hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client2", "tag1") + hasTag, err = suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client2", "tag1") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.Assert().Equal(hasTag, false) - hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client2", "tag2") + hasTag, err = suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client2", "tag2") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.Assert().Equal(hasTag, true) } @@ -213,24 +214,24 @@ func (suite *SyncEntityTestSuite) TestHasItem() { entity2.ClientID = "client2" entity2.ID = "id2" - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") - hasTag, err := suite.dynamo.HasItem("client1", "id1") + hasTag, err := suite.dynamo.HasItem(context.Background(), "client1", "id1") suite.Require().NoError(err, "HasItem should succeed") suite.Assert().Equal(hasTag, true) - hasTag, err = suite.dynamo.HasItem("client2", "id2") + hasTag, err = suite.dynamo.HasItem(context.Background(), "client2", "id2") suite.Require().NoError(err, "HasItem should succeed") suite.Assert().Equal(hasTag, true) - hasTag, err = suite.dynamo.HasItem("client2", "id3") + hasTag, err = suite.dynamo.HasItem(context.Background(), "client2", "id3") suite.Require().NoError(err, "HasItem should succeed") suite.Assert().Equal(hasTag, false) - hasTag, err = suite.dynamo.HasItem("client3", "id2") + hasTag, err = suite.dynamo.HasItem(context.Background(), "client3", "id2") suite.Require().NoError(err, "HasItem should succeed") suite.Assert().Equal(hasTag, false) } @@ -253,7 +254,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { entity2.ID = "id2" entities := []*datastore.SyncEntity{&entity1, &entity2} suite.Require().Error( - suite.dynamo.InsertSyncEntitiesWithServerTags(entities), + suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities), "Insert with same ClientID and server tag would fail") // Check nothing is written to DB when it fails. @@ -270,7 +271,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { entity3.ID = "id3" entities = []*datastore.SyncEntity{&entity1, &entity2, &entity3} suite.Require().NoError( - suite.dynamo.InsertSyncEntitiesWithServerTags(entities), + suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities), "InsertSyncEntitiesWithServerTags should succeed") // Scan DB and check all items are saved @@ -319,11 +320,11 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { entity2.ID = "id2" entity3 := entity1 entity3.ID = "id3" - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity3) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity3) suite.Require().NoError(err, "InsertSyncEntity should succeed") // Check sync entities are inserted correctly in DB. syncItems, err := datastoretest.ScanSyncEntities(suite.dynamo) @@ -338,7 +339,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity1.Deleted = aws.Bool(true) updateEntity1.DataTypeMtime = aws.String("123#23456789") updateEntity1.Specifics = []byte{3, 4} - conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) + conflict, deleted, err := suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().True(deleted, "Delete operation should return true") @@ -352,7 +353,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity2.ParentID = aws.String("parentID") updateEntity2.Name = aws.String("name") updateEntity2.NonUniqueName = aws.String("non_unique_name") - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, *entity2.Version) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity2, *entity2.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().False(deleted, "Non-delete operation should return false") @@ -362,7 +363,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity3.ID = "id3" updateEntity3.Folder = nil updateEntity3.Deleted = nil - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity3, *entity3.Version) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity3, *entity3.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().False(deleted, "Non-delete operation should return false") @@ -372,7 +373,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { // Update entity again with the wrong old version as (version mismatch) // should return false. - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, 12345678) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity2, 12345678) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().True(conflict, "Update with the same version should return conflict") suite.Assert().False(deleted, "Conflict operation should return false for delete") @@ -398,7 +399,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { DataTypeMtime: aws.String("123#12345678"), Specifics: []byte{1, 2}, } - conflict, err := suite.dynamo.InsertSyncEntity(&entity1) + conflict, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") suite.Assert().False(conflict, "Successful insert should not have conflict") @@ -406,7 +407,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { updateEntity1.Version = aws.Int64(2) updateEntity1.Folder = aws.Bool(true) updateEntity1.Mtime = aws.Int64(24242424) - conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, 1) + conflict, deleted, err := suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, 1) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().False(deleted, "Non-delete operation should return false") @@ -415,7 +416,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { // since the version number should be ignored updateEntity2 := updateEntity1 updateEntity2.Mtime = aws.Int64(42424242) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, 1) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity2, 1) suite.Require().NoError(err, "UpdateSyncEntity should not return an error") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().False(deleted, "Non-delete operation should return false") @@ -423,7 +424,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { updateEntity3 := entity1 updateEntity3.Deleted = aws.Bool(true) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity3, 1) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity3, 1) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().True(deleted, "Delete operation should return true") @@ -448,7 +449,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { DataTypeMtime: aws.String("123#12345678"), Specifics: []byte{1, 2}, } - conflict, err := suite.dynamo.InsertSyncEntity(&entity1) + conflict, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") suite.Assert().False(conflict, "Successful insert should not have conflict") @@ -464,21 +465,21 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { updateEntity1.Folder = aws.Bool(true) updateEntity1.DataTypeMtime = aws.String("123#23456789") updateEntity1.Specifics = []byte{3, 4} - conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) + conflict, deleted, err := suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().False(deleted, "Non-delete operation should return false") // Soft-delete the item with wrong version should get conflict. updateEntity1.Deleted = aws.Bool(true) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().True(conflict, "Version mismatched update should have conflict") suite.Assert().False(deleted, "Failed delete operation should return false") // Soft-delete the item with matched version. updateEntity1.Version = aws.Int64(34567890) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity1, 23456789) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, 23456789) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.Assert().False(conflict, "Successful update should not have conflict") suite.Assert().True(deleted, "Delete operation should return true") @@ -491,7 +492,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { // Insert another item with the same client tag again. entity2 := entity1 entity2.ID = "id2" - conflict, err = suite.dynamo.InsertSyncEntity(&entity2) + conflict, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") suite.Assert().False(conflict, "Successful insert should not have conflict") @@ -539,63 +540,63 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { entity5.ID = "id5" entity5.ExpirationTime = aws.Int64(time.Now().Unix() - 300) - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity3) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity3) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity4) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity4) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity5) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity5) suite.Require().NoError(err, "InsertSyncEntity should succeed") // Get all updates for type 123 and client1 using token = 0. - hasChangesRemaining, syncItems, err := suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 100) + hasChangesRemaining, syncItems, err := suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1, entity2}) suite.Assert().False(hasChangesRemaining) // Get all updates for type 124 and client1 using token = 0. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(124, 0, true, "client1", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 124, 0, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity3}) suite.Assert().False(hasChangesRemaining) // Get all updates for type 123 and client2 using token = 0. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client2", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client2", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity4}) suite.Assert().False(hasChangesRemaining) // Get all updates for type 124 and client2 using token = 0. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(124, 0, true, "client2", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 124, 0, true, "client2", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(len(syncItems), 0) suite.Assert().False(hasChangesRemaining) // Test maxSize will limit the return entries size, and hasChangesRemaining // should be true when there are more updates available in the DB. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 1) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 1) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1}) suite.Assert().True(hasChangesRemaining) // Test when num of query items equal to the limit, hasChangesRemaining should // be true. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 2) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 2) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity1, entity2}) suite.Assert().True(hasChangesRemaining) // Test fetchFolders will remove folder items if false - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, false, "client1", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, false, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity2}) suite.Assert().False(hasChangesRemaining) // Get all updates for a type for a client using mtime of one item as token. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 12345678, true, "client1", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 12345678, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, []datastore.SyncEntity{entity2}) suite.Assert().False(hasChangesRemaining) @@ -622,13 +623,13 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { entity.ID = "id" + strconv.Itoa(i) entity.Mtime = aws.Int64(mtime) entity.DataTypeMtime = aws.String("123#" + strconv.FormatInt(*entity.Mtime, 10)) - _, err := suite.dynamo.InsertSyncEntity(&entity) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity) suite.Require().NoError(err, "InsertSyncEntity should succeed") expectedSyncItems = append(expectedSyncItems, entity) } // All items should be returned and sorted by Mtime. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 300) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 300) suite.Require().NoError(err, "GetUpdatesForType should succeed") sort.Sort(datastore.SyncEntityByMtime(expectedSyncItems)) suite.Assert().Equal(syncItems, expectedSyncItems) @@ -636,7 +637,7 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { // Test that when maxGUBatchSize is smaller than total updates, the first n // items ordered by Mtime should be returned. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 200) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 200) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Assert().Equal(syncItems, expectedSyncItems[0:200]) suite.Assert().True(hasChangesRemaining) @@ -847,7 +848,7 @@ func (suite *SyncEntityTestSuite) TestCreatePBSyncEntity() { func (suite *SyncEntityTestSuite) TestDisableSyncChain() { clientID := "client1" id := "disabled_chain" - err := suite.dynamo.DisableSyncChain(clientID) + err := suite.dynamo.DisableSyncChain(context.Background(), clientID) suite.Require().NoError(err, "DisableSyncChain should succeed") e, err := datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") @@ -859,13 +860,13 @@ func (suite *SyncEntityTestSuite) TestDisableSyncChain() { func (suite *SyncEntityTestSuite) TestIsSyncChainDisabled() { clientID := "client1" - disabled, err := suite.dynamo.IsSyncChainDisabled(clientID) + disabled, err := suite.dynamo.IsSyncChainDisabled(context.Background(), clientID) suite.Require().NoError(err, "IsSyncChainDisabled should succeed") suite.Assert().Equal(false, disabled) - err = suite.dynamo.DisableSyncChain(clientID) + err = suite.dynamo.DisableSyncChain(context.Background(), clientID) suite.Require().NoError(err, "DisableSyncChain should succeed") - disabled, err = suite.dynamo.IsSyncChainDisabled(clientID) + disabled, err = suite.dynamo.IsSyncChainDisabled(context.Background(), clientID) suite.Require().NoError(err, "IsSyncChainDisabled should succeed") suite.Assert().Equal(true, disabled) } @@ -883,14 +884,14 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { Deleted: aws.Bool(false), DataTypeMtime: aws.String("123#12345678"), } - _, err := suite.dynamo.InsertSyncEntity(&entity) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity) suite.Require().NoError(err, "InsertSyncEntity should succeed") e, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") suite.Assert().Equal(1, len(e)) - e, err = suite.dynamo.ClearServerData(entity.ClientID) + e, err = suite.dynamo.ClearServerData(context.Background(), entity.ClientID) suite.Require().NoError(err, "ClearServerData should succeed") suite.Assert().Equal(1, len(e)) @@ -916,7 +917,7 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { entity2.ServerDefinedUniqueTag = aws.String("tag2") entities := []*datastore.SyncEntity{&entity1, &entity2} suite.Require().NoError( - suite.dynamo.InsertSyncEntitiesWithServerTags(entities), + suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities), "InsertSyncEntitiesWithServerTags should succeed") e, err = datastoretest.ScanSyncEntities(suite.dynamo) @@ -927,7 +928,7 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { suite.Require().NoError(err, "ScanTagItems should succeed") suite.Assert().Equal(2, len(t), "No items should be written if fail") - e, err = suite.dynamo.ClearServerData(entity.ClientID) + e, err = suite.dynamo.ClearServerData(context.Background(), entity.ClientID) suite.Require().NoError(err, "ClearServerData should succeed") suite.Assert().Equal(4, len(e)) diff --git a/middleware/disabled_chain.go b/middleware/disabled_chain.go index 282e5382..a9e5d7a6 100644 --- a/middleware/disabled_chain.go +++ b/middleware/disabled_chain.go @@ -27,7 +27,7 @@ func DisabledChain(next http.Handler) http.Handler { return } - disabled, err := db.IsSyncChainDisabled(clientID) + disabled, err := db.IsSyncChainDisabled(ctx, clientID) if err != nil { http.Error(w, "unable to complete request", http.StatusInternalServerError) From 40f3782e079c5338fe3e4257f6ec2a48c7c14b6f Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Fri, 24 Oct 2025 03:45:19 -0500 Subject: [PATCH 03/11] fix test ctx --- middleware/middleware_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 277f5236..1e0f0ac3 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -25,9 +25,9 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { // Active Chain datastore := new(datastoretest.MockDatastore) - datastore.On("IsSyncChainDisabled", clientID).Return(false, nil) ctx := context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) + datastore.On("IsSyncChainDisabled", ctx, clientID).Return(false, nil) next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) handler := middleware.DisabledChain(next) req, err := http.NewRequestWithContext(ctx, "POST", "v2/command/", bytes.NewBuffer([]byte{})) @@ -38,9 +38,9 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { // Disabled chain datastore = new(datastoretest.MockDatastore) - datastore.On("IsSyncChainDisabled", clientID).Return(true, nil) ctx = context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) + datastore.On("IsSyncChainDisabled", ctx, clientID).Return(true, nil) next = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { suite.Require().Equal(false, true) }) @@ -53,12 +53,11 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { // DB error datastore = new(datastoretest.MockDatastore) - datastore.On("IsSyncChainDisabled", clientID).Return(false, fmt.Errorf("unable to query db")) ctx = context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) + datastore.On("IsSyncChainDisabled", ctx, clientID).Return(false, fmt.Errorf("unable to query db")) next = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) handler = middleware.DisabledChain(next) - rr = httptest.NewRecorder() req, err = http.NewRequestWithContext(ctx, "POST", "v2/command/", bytes.NewBuffer([]byte{})) suite.Require().NoError(err, "NewRequestWithContext should succeed") rr = httptest.NewRecorder() From 8d35a2177145766a7a6cc0cfcabf0f32028bac18 Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Fri, 24 Oct 2025 18:38:20 -0500 Subject: [PATCH 04/11] fix --- command/command.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/command/command.go b/command/command.go index 2533c070..c155ce99 100644 --- a/command/command.go +++ b/command/command.go @@ -130,7 +130,7 @@ func handleGetUpdatesRequest(ctx context.Context, cache *cache.Cache, guMsg *syn // Check cache to short circuit with 0 updates for polling requests. if isPoll && - !cache.IsTypeMtimeUpdated(context.Background(), clientID, int(*fromProgressMarker.DataTypeId), token) { + !cache.IsTypeMtimeUpdated(ctx, clientID, int(*fromProgressMarker.DataTypeId), token) { continue } @@ -187,7 +187,7 @@ func handleGetUpdatesRequest(ctx context.Context, cache *cache.Cache, guMsg *syn } else { mtime = *entities[j-1].Mtime } - cache.SetTypeMtime(context.Background(), clientID, int(*fromProgressMarker.DataTypeId), mtime) + cache.SetTypeMtime(ctx, clientID, int(*fromProgressMarker.DataTypeId), mtime) } } @@ -199,19 +199,19 @@ func getItemCounts(ctx context.Context, cache *cache.Cache, db datastore.Datasto if err != nil { return nil, 0, 0, err } - newNormalCount, newHistoryCount, err := getInterimItemCounts(cache, clientID, false) + newNormalCount, newHistoryCount, err := getInterimItemCounts(ctx, cache, clientID, false) if err != nil { return nil, 0, 0, err } return itemCounts, newNormalCount, newHistoryCount, nil } -func getInterimItemCounts(cache *cache.Cache, clientID string, clearCache bool) (int, int, error) { - newNormalCount, err := cache.GetInterimCount(context.Background(), clientID, normalCountTypeStr, clearCache) +func getInterimItemCounts(ctx context.Context, cache *cache.Cache, clientID string, clearCache bool) (int, int, error) { + newNormalCount, err := cache.GetInterimCount(ctx, clientID, normalCountTypeStr, clearCache) if err != nil { return 0, 0, err } - newHistoryCount, err := cache.GetInterimCount(context.Background(), clientID, historyCountTypeStr, clearCache) + newHistoryCount, err := cache.GetInterimCount(ctx, clientID, historyCountTypeStr, clearCache) if err != nil { return 0, 0, err } @@ -324,9 +324,9 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn } if isHistoryRelatedItem { - newHistoryCount, err = cache.IncrementInterimCount(context.Background(), clientID, historyCountTypeStr, false) + newHistoryCount, _ = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, false) } else { - newNormalCount, err = cache.IncrementInterimCount(context.Background(), clientID, normalCountTypeStr, false) + newNormalCount, _ = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, false) } } } else { // Update @@ -345,9 +345,9 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn } if deleted { if isHistoryRelatedItem { - newHistoryCount, err = cache.IncrementInterimCount(context.Background(), clientID, historyCountTypeStr, true) + newHistoryCount, _ = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, true) } else { - newNormalCount, err = cache.IncrementInterimCount(context.Background(), clientID, normalCountTypeStr, true) + newNormalCount, _ = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, true) } } } @@ -366,7 +366,7 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn entryRsp.Mtime = entityToCommit.Mtime } - newNormalCount, newHistoryCount, err = getInterimItemCounts(cache, clientID, true) + newNormalCount, newHistoryCount, err = getInterimItemCounts(ctx, cache, clientID, true) if err != nil { log.Error().Err(err).Msg("Get interim item counts failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -375,7 +375,7 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn // Save (clientID#dataType, mtime) into cache after writing into DB. for dataType, mtime := range typeMtimeMap { - cache.SetTypeMtime(context.Background(), clientID, dataType, mtime) + cache.SetTypeMtime(ctx, clientID, dataType, mtime) } err = db.UpdateClientItemCount(ctx, itemCounts, newNormalCount, newHistoryCount) @@ -420,7 +420,7 @@ func handleClearServerDataRequest(ctx context.Context, cache *cache.Cache, db da } if len(typeMtimeCacheKeys) > 0 { - err = cache.Del(context.Background(), typeMtimeCacheKeys...) + err = cache.Del(ctx, typeMtimeCacheKeys...) if err != nil { log.Error().Err(err).Msg("Failed to clear cache") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR From 8c939cac034d2ceb31277554135f8abb0a0c42ae Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Mon, 27 Oct 2025 22:19:23 -0500 Subject: [PATCH 05/11] fix: use context.Background for handleClearServerDataRequest --- command/command.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/command/command.go b/command/command.go index c155ce99..c06a51ee 100644 --- a/command/command.go +++ b/command/command.go @@ -475,7 +475,7 @@ func HandleClientToServerMessage(ctx context.Context, cache *cache.Cache, pb *sy } else if *pb.MessageContents == sync_pb.ClientToServerMessage_CLEAR_SERVER_DATA { csdRsp := &sync_pb.ClearServerDataResponse{} pbRsp.ClearServerData = csdRsp - pbRsp.ErrorCode, err = handleClearServerDataRequest(ctx, cache, db, pb.ClearServerData, clientID) + pbRsp.ErrorCode, err = handleClearServerDataRequest(context.Background(), cache, db, pb.ClearServerData, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) From 49922f7a87b2399bde5af0deed80a82c87469810 Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Mon, 27 Oct 2025 22:32:06 -0500 Subject: [PATCH 06/11] gowrap --- cache/instrumented_redis.go | 26 +++++++++++++------------- datastore/instrumented_datastore.go | 12 ++++++------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cache/instrumented_redis.go b/cache/instrumented_redis.go index 032d2f1f..ad9aa14b 100755 --- a/cache/instrumented_redis.go +++ b/cache/instrumented_redis.go @@ -1,10 +1,10 @@ -package cache +// Code generated by gowrap. DO NOT EDIT. +// template: ../.prom-gowrap.tmpl +// gowrap: http://github.com/hexdigest/gowrap -// DO NOT EDIT! -// This code is generated with http://github.com/hexdigest/gowrap tool -// using ../.prom-gowrap.tmpl template +package cache -//go:generate gowrap gen -p github.com/brave/go-sync/cache -i RedisClient -t ../.prom-gowrap.tmpl -o instrumented_redis.go +//go:generate gowrap gen -p github.com/brave/go-sync/cache -i RedisClient -t ../.prom-gowrap.tmpl -o instrumented_redis.go -l "" import ( "context" @@ -80,8 +80,8 @@ func (_d RedisClientWithPrometheus) Get(ctx context.Context, key string, deleteA return _d.base.Get(ctx, key, deleteAfterGet) } -// Set implements RedisClient -func (_d RedisClientWithPrometheus) Set(ctx context.Context, key string, val string, ttl time.Duration) (err error) { +// Incr implements RedisClient +func (_d RedisClientWithPrometheus) Incr(ctx context.Context, key string, subtract bool) (i1 int, err error) { _since := time.Now() defer func() { result := "ok" @@ -89,13 +89,13 @@ func (_d RedisClientWithPrometheus) Set(ctx context.Context, key string, val str result = "error" } - redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Set", result).Observe(time.Since(_since).Seconds()) + redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Incr", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.Set(ctx, key, val, ttl) + return _d.base.Incr(ctx, key, subtract) } -// Incr implements RedisClient -func (_d RedisClientWithPrometheus) Incr(ctx context.Context, key string, subtract bool) (val int, err error) { +// Set implements RedisClient +func (_d RedisClientWithPrometheus) Set(ctx context.Context, key string, val string, ttl time.Duration) (err error) { _since := time.Now() defer func() { result := "ok" @@ -103,7 +103,7 @@ func (_d RedisClientWithPrometheus) Incr(ctx context.Context, key string, subtra result = "error" } - redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Incr", result).Observe(time.Since(_since).Seconds()) + redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Set", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.Incr(ctx, key, subtract) + return _d.base.Set(ctx, key, val, ttl) } diff --git a/datastore/instrumented_datastore.go b/datastore/instrumented_datastore.go index 9765ed26..2abecd13 100644 --- a/datastore/instrumented_datastore.go +++ b/datastore/instrumented_datastore.go @@ -1,10 +1,10 @@ -package datastore +// Code generated by gowrap. DO NOT EDIT. +// template: ../.prom-gowrap.tmpl +// gowrap: http://github.com/hexdigest/gowrap -// DO NOT EDIT! -// This code is generated with http://github.com/hexdigest/gowrap tool -// using ../.prom-gowrap.tmpl template +package datastore -//go:generate gowrap gen -p github.com/brave/go-sync/datastore -i Datastore -t ../.prom-gowrap.tmpl -o instrumented_datastore.go +//go:generate gowrap gen -p github.com/brave/go-sync/datastore -i Datastore -t ../.prom-gowrap.tmpl -o instrumented_datastore.go -l "" import ( "context" @@ -67,7 +67,7 @@ func (_d DatastoreWithPrometheus) DisableSyncChain(ctx context.Context, clientID } // GetClientItemCount implements Datastore -func (_d DatastoreWithPrometheus) GetClientItemCount(ctx context.Context, clientID string) (counts *ClientItemCounts, err error) { +func (_d DatastoreWithPrometheus) GetClientItemCount(ctx context.Context, clientID string) (cp1 *ClientItemCounts, err error) { _since := time.Now() defer func() { result := "ok" From bad39a1a8137e0c607be4b6914e3438dff8de60e Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Tue, 28 Oct 2025 20:22:01 -0500 Subject: [PATCH 07/11] fix: shadowed err var --- command/command.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/command/command.go b/command/command.go index c06a51ee..84f70b1e 100644 --- a/command/command.go +++ b/command/command.go @@ -293,6 +293,7 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn } } + var interimErr error if !isUpdateOp { // Create if currentNormalItemCount+currentHistoryItemCount+newNormalCount+newHistoryCount >= maxClientObjectQuota+boostedQuotaAddition { rspType := sync_pb.CommitResponse_OVER_QUOTA @@ -324,9 +325,9 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn } if isHistoryRelatedItem { - newHistoryCount, _ = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, false) + newHistoryCount, interimErr = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, false) } else { - newNormalCount, _ = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, false) + newNormalCount, interimErr = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, false) } } } else { // Update @@ -345,16 +346,16 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn } if deleted { if isHistoryRelatedItem { - newHistoryCount, _ = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, true) + newHistoryCount, interimErr = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, true) } else { - newNormalCount, _ = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, true) + newNormalCount, interimErr = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, true) } } } - if err != nil { - log.Error().Err(err).Msg("Interim count update failed") + if interimErr != nil { + log.Error().Err(interimErr).Msg("Interim count update failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR - return &errCode, fmt.Errorf("Interim count update failed: %w", err) + return &errCode, fmt.Errorf("Interim count update failed: %w", interimErr) } typeMtimeMap[*entityToCommit.DataType] = *entityToCommit.Mtime From 678f584390de60de3702b433ea3b9b8589ff1345 Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Wed, 29 Oct 2025 01:45:35 -0500 Subject: [PATCH 08/11] make SetTypeMtime optimistic --- command/command.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/command/command.go b/command/command.go index 84f70b1e..338a8417 100644 --- a/command/command.go +++ b/command/command.go @@ -255,8 +255,6 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn // Map client-generated ID to its server-generated ID. idMap := make(map[string]string) - // Map to save commit data type ID & mtime - typeMtimeMap := make(map[int]int64) for i, v := range commitMsg.Entries { entryRsp := &sync_pb.CommitResponse_EntryResponse{} commitRsp.Entryresponse[i] = entryRsp @@ -281,6 +279,7 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn isUpdateOp := oldVersion != 0 isHistoryRelatedItem := *entityToCommit.DataType == datastore.HistoryTypeID || *entityToCommit.DataType == datastore.HistoryDeleteDirectiveTypeID *entityToCommit.Version = *entityToCommit.Mtime + if *entityToCommit.DataType == datastore.HistoryTypeID { // Check if item exists using client_unique_tag isUpdateOp, err = db.HasItem(ctx, clientID, *entityToCommit.ClientDefinedUniqueTag) @@ -293,6 +292,9 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn } } + // Optimistically update cache with the latest mtime for this data type. + cache.SetTypeMtime(ctx, clientID, *entityToCommit.DataType, *entityToCommit.Mtime) + var interimErr error if !isUpdateOp { // Create if currentNormalItemCount+currentHistoryItemCount+newNormalCount+newHistoryCount >= maxClientObjectQuota+boostedQuotaAddition { @@ -358,7 +360,6 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn return &errCode, fmt.Errorf("Interim count update failed: %w", interimErr) } - typeMtimeMap[*entityToCommit.DataType] = *entityToCommit.Mtime // Prepare success response rspType := sync_pb.CommitResponse_SUCCESS entryRsp.ResponseType = &rspType @@ -374,11 +375,6 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn return &errCode, fmt.Errorf("error getting interim item count: %w", err) } - // Save (clientID#dataType, mtime) into cache after writing into DB. - for dataType, mtime := range typeMtimeMap { - cache.SetTypeMtime(ctx, clientID, dataType, mtime) - } - err = db.UpdateClientItemCount(ctx, itemCounts, newNormalCount, newHistoryCount) if err != nil { // We only impose a soft quota limit on the item count for each client, so From 861d0133c243cf68c2b23a68a46ce97b5caeae91 Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Tue, 2 Dec 2025 03:31:24 -0500 Subject: [PATCH 09/11] fix: use context.TODO() for handleCommitRequest --- command/command.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/command/command.go b/command/command.go index 72745b05..5c92ae72 100644 --- a/command/command.go +++ b/command/command.go @@ -460,7 +460,7 @@ func HandleClientToServerMessage(ctx context.Context, cache *cache.Cache, pb *sy } else if *pb.MessageContents == sync_pb.ClientToServerMessage_COMMIT { commitRsp := &sync_pb.CommitResponse{} pbRsp.Commit = commitRsp - pbRsp.ErrorCode, err = handleCommitRequest(ctx, cache, pb.Commit, commitRsp, db, clientID) + pbRsp.ErrorCode, err = handleCommitRequest(context.TODO(), cache, pb.Commit, commitRsp, db, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) From afcbf3ea6d958d143c3dc72c4d6c33310748fc17 Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Thu, 4 Dec 2025 00:09:43 -0500 Subject: [PATCH 10/11] Revert "make SetTypeMtime optimistic" This reverts commit 678f584390de60de3702b433ea3b9b8589ff1345. --- command/command.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/command/command.go b/command/command.go index 5c92ae72..63899bce 100644 --- a/command/command.go +++ b/command/command.go @@ -257,6 +257,8 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn // Map client-generated ID to its server-generated ID. idMap := make(map[string]string) + // Map to save commit data type ID & mtime + typeMtimeMap := make(map[int]int64) for i, v := range commitMsg.Entries { entryRsp := &sync_pb.CommitResponse_EntryResponse{} commitRsp.Entryresponse[i] = entryRsp @@ -281,7 +283,6 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn isUpdateOp := oldVersion != 0 isHistoryRelatedItem := *entityToCommit.DataType == datastore.HistoryTypeID || *entityToCommit.DataType == datastore.HistoryDeleteDirectiveTypeID *entityToCommit.Version = *entityToCommit.Mtime - if *entityToCommit.DataType == datastore.HistoryTypeID { // Check if item exists using client_unique_tag isUpdateOp, err = db.HasItem(ctx, clientID, *entityToCommit.ClientDefinedUniqueTag) @@ -294,9 +295,6 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn } } - // Optimistically update cache with the latest mtime for this data type. - cache.SetTypeMtime(ctx, clientID, *entityToCommit.DataType, *entityToCommit.Mtime) - var interimErr error if !isUpdateOp { // Create if currentNormalItemCount+currentHistoryItemCount+newNormalCount+newHistoryCount >= maxClientObjectQuota+boostedQuotaAddition { @@ -362,6 +360,7 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn return &errCode, fmt.Errorf("interim count update failed: %w", interimErr) } + typeMtimeMap[*entityToCommit.DataType] = *entityToCommit.Mtime // Prepare success response rspType := sync_pb.CommitResponse_SUCCESS entryRsp.ResponseType = &rspType @@ -377,6 +376,11 @@ func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *syn return &errCode, fmt.Errorf("error getting interim item count: %w", err) } + // Save (clientID#dataType, mtime) into cache after writing into DB. + for dataType, mtime := range typeMtimeMap { + cache.SetTypeMtime(ctx, clientID, dataType, mtime) + } + err = db.UpdateClientItemCount(ctx, itemCounts, newNormalCount, newHistoryCount) if err != nil { // We only impose a soft quota limit on the item count for each client, so From 7ec956968d22a6bbd0193a0b34c045873d6a8ee0 Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Wed, 21 Jan 2026 03:06:44 -0500 Subject: [PATCH 11/11] fix tests --- command/command_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/command/command_test.go b/command/command_test.go index f3c12a03..1f8bb3c9 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -409,7 +409,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceed rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, testCase.clientID), "HandleClientToServerMessage should succeed for device %d", i) suite.Equal(sync_pb.SyncEnums_SUCCESS, *rsp.ErrorCode, "device %d should succeed", i) suite.NotNil(rsp.GetUpdates, "device %d should have GetUpdates response", i) @@ -419,7 +419,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceed commitMsg := getClientToServerCommitMsg([]*sync_pb.SyncEntity{deviceEntry}) commitRsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, commitMsg, commitRsp, suite.dynamo, testCase.clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, commitMsg, commitRsp, suite.dynamo, testCase.clientID), "Commit device info should succeed for device %d", i) suite.Equal(sync_pb.SyncEnums_SUCCESS, *commitRsp.ErrorCode, "Commit device info should succeed for device %d", i) } @@ -427,7 +427,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceed // should get THROTTLED error when device limit is exceeded rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, testCase.clientID), "HandleClientToServerMessage should succeed") suite.Equal(sync_pb.SyncEnums_THROTTLED, *rsp.ErrorCode, "errorCode should be THROTTLED") suite.Require().NotNil(rsp.ErrorMessage, "error message should be present")