From a11a468856950638f38e2b7cd033a48da49f684f Mon Sep 17 00:00:00 2001 From: mschfh <37435502+mschfh@users.noreply.github.com> Date: Wed, 21 Jan 2026 13:10:03 -0500 Subject: [PATCH] fix: improve context propagation (#374) Co-authored-by: Darnell Andries --- cache/instrumented_redis.go | 26 ++-- command/command.go | 72 +++++------ command/command_test.go | 58 ++++----- command/server_defined_unique_entity.go | 7 +- command/server_defined_unique_entity_test.go | 7 +- controller/controller.go | 2 +- datastore/datastore.go | 24 ++-- datastore/datastoretest/mock_datastore.go | 46 +++---- datastore/instrumented_datastore.go | 55 ++++----- datastore/item_count.go | 16 +-- datastore/item_count_test.go | 11 +- datastore/sync_entity.go | 44 +++---- datastore/sync_entity_test.go | 121 ++++++++++--------- middleware/disabled_chain.go | 2 +- middleware/middleware_test.go | 6 +- 15 files changed, 253 insertions(+), 244 deletions(-) diff --git a/cache/instrumented_redis.go b/cache/instrumented_redis.go index 032d2f1f..ad9aa14b 100755 --- a/cache/instrumented_redis.go +++ b/cache/instrumented_redis.go @@ -1,10 +1,10 @@ -package cache +// Code generated by gowrap. DO NOT EDIT. +// template: ../.prom-gowrap.tmpl +// gowrap: http://github.com/hexdigest/gowrap -// DO NOT EDIT! -// This code is generated with http://github.com/hexdigest/gowrap tool -// using ../.prom-gowrap.tmpl template +package cache -//go:generate gowrap gen -p github.com/brave/go-sync/cache -i RedisClient -t ../.prom-gowrap.tmpl -o instrumented_redis.go +//go:generate gowrap gen -p github.com/brave/go-sync/cache -i RedisClient -t ../.prom-gowrap.tmpl -o instrumented_redis.go -l "" import ( "context" @@ -80,8 +80,8 @@ func (_d RedisClientWithPrometheus) Get(ctx context.Context, key string, deleteA return _d.base.Get(ctx, key, deleteAfterGet) } -// Set implements RedisClient -func (_d RedisClientWithPrometheus) Set(ctx context.Context, key string, val string, ttl time.Duration) (err error) { +// Incr implements RedisClient +func (_d RedisClientWithPrometheus) Incr(ctx context.Context, key string, subtract bool) (i1 int, err error) { _since := time.Now() defer func() { result := "ok" @@ -89,13 +89,13 @@ func (_d RedisClientWithPrometheus) Set(ctx context.Context, key string, val str result = "error" } - redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Set", result).Observe(time.Since(_since).Seconds()) + redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Incr", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.Set(ctx, key, val, ttl) + return _d.base.Incr(ctx, key, subtract) } -// Incr implements RedisClient -func (_d RedisClientWithPrometheus) Incr(ctx context.Context, key string, subtract bool) (val int, err error) { +// Set implements RedisClient +func (_d RedisClientWithPrometheus) Set(ctx context.Context, key string, val string, ttl time.Duration) (err error) { _since := time.Now() defer func() { result := "ok" @@ -103,7 +103,7 @@ func (_d RedisClientWithPrometheus) Incr(ctx context.Context, key string, subtra result = "error" } - redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Incr", result).Observe(time.Since(_since).Seconds()) + redisclientDurationSummaryVec.WithLabelValues(_d.instanceName, "Set", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.Incr(ctx, key, subtract) + return _d.base.Set(ctx, key, val, ttl) } diff --git a/command/command.go b/command/command.go index 8b8b541b..121c7647 100644 --- a/command/command.go +++ b/command/command.go @@ -34,7 +34,7 @@ const ( // handleGetUpdatesRequest handles GetUpdatesMessage and fills // GetUpdatesResponse. Target sync entities in the database will be updated or // deleted based on the client's requests. -func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessage, guRsp *sync_pb.GetUpdatesResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { +func handleGetUpdatesRequest(ctx context.Context, cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessage, guRsp *sync_pb.GetUpdatesResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { errCode := sync_pb.SyncEnums_SUCCESS // default value, might be changed later isNewClient := guMsg.GetUpdatesOrigin != nil && *guMsg.GetUpdatesOrigin == sync_pb.SyncEnums_NEW_CLIENT isPoll := guMsg.GetUpdatesOrigin != nil && *guMsg.GetUpdatesOrigin == sync_pb.SyncEnums_PERIODIC @@ -42,7 +42,7 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag // Reject the request if client has >= 50 devices in the chain. activeDevices := 0 for { - hasChangesRemaining, syncEntities, err := db.GetUpdatesForType(deviceInfoTypeID, 0, false, clientID, int64(maxGUBatchSize)) + hasChangesRemaining, syncEntities, err := db.GetUpdatesForType(ctx, deviceInfoTypeID, 0, false, clientID, int64(maxGUBatchSize)) if err != nil { log.Error().Err(err).Msgf("db.GetUpdatesForType failed for type %v", deviceInfoTypeID) errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -69,7 +69,7 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag } // Insert initial records if needed. - err := InsertServerDefinedUniqueEntities(db, clientID) + err := InsertServerDefinedUniqueEntities(ctx, db, clientID) if err != nil { log.Error().Err(err).Msg("Create server defined unique entities failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -131,12 +131,12 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag // Check cache to short circuit with 0 updates for polling requests. if isPoll && - !cache.IsTypeMtimeUpdated(context.Background(), clientID, int(*fromProgressMarker.DataTypeId), token) { + !cache.IsTypeMtimeUpdated(ctx, clientID, int(*fromProgressMarker.DataTypeId), token) { continue } curMaxSize := int64(maxSize) - int64(len(guRsp.Entries)) - hasChangesRemaining, entities, err := db.GetUpdatesForType(int(*fromProgressMarker.DataTypeId), token, fetchFolders, clientID, curMaxSize) + hasChangesRemaining, entities, err := db.GetUpdatesForType(ctx, int(*fromProgressMarker.DataTypeId), token, fetchFolders, clientID, curMaxSize) if err != nil { log.Error().Err(err).Msgf("db.GetUpdatesForType failed for type %v", *fromProgressMarker.DataTypeId) errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -188,31 +188,31 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag } else { mtime = *entities[j-1].Mtime } - cache.SetTypeMtime(context.Background(), clientID, int(*fromProgressMarker.DataTypeId), mtime) + cache.SetTypeMtime(ctx, clientID, int(*fromProgressMarker.DataTypeId), mtime) } } return &errCode, nil } -func getItemCounts(cache *cache.Cache, db datastore.Datastore, clientID string) (*datastore.ClientItemCounts, int, int, error) { - itemCounts, err := db.GetClientItemCount(clientID) +func getItemCounts(ctx context.Context, cache *cache.Cache, db datastore.Datastore, clientID string) (*datastore.ClientItemCounts, int, int, error) { + itemCounts, err := db.GetClientItemCount(ctx, clientID) if err != nil { return nil, 0, 0, err } - newNormalCount, newHistoryCount, err := getInterimItemCounts(cache, clientID, false) + newNormalCount, newHistoryCount, err := getInterimItemCounts(ctx, cache, clientID, false) if err != nil { return nil, 0, 0, err } return itemCounts, newNormalCount, newHistoryCount, nil } -func getInterimItemCounts(cache *cache.Cache, clientID string, clearCache bool) (int, int, error) { - newNormalCount, err := cache.GetInterimCount(context.Background(), clientID, normalCountTypeStr, clearCache) +func getInterimItemCounts(ctx context.Context, cache *cache.Cache, clientID string, clearCache bool) (int, int, error) { + newNormalCount, err := cache.GetInterimCount(ctx, clientID, normalCountTypeStr, clearCache) if err != nil { return 0, 0, err } - newHistoryCount, err := cache.GetInterimCount(context.Background(), clientID, historyCountTypeStr, clearCache) + newHistoryCount, err := cache.GetInterimCount(ctx, clientID, historyCountTypeStr, clearCache) if err != nil { return 0, 0, err } @@ -223,7 +223,7 @@ func getInterimItemCounts(cache *cache.Cache, clientID string, clearCache bool) // For each commit entry: // - new sync entity is created and inserted into the database if version is 0. // - existed sync entity will be updated if version is greater than 0. -func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, commitRsp *sync_pb.CommitResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { +func handleCommitRequest(ctx context.Context, cache *cache.Cache, commitMsg *sync_pb.CommitMessage, commitRsp *sync_pb.CommitResponse, db datastore.Datastore, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { if commitMsg == nil { return nil, errors.New("nil commitMsg is received") } @@ -233,7 +233,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c return &errCode, nil } - itemCounts, newNormalCount, newHistoryCount, err := getItemCounts(cache, db, clientID) + itemCounts, newNormalCount, newHistoryCount, err := getItemCounts(ctx, cache, db, clientID) if err != nil { log.Error().Err(err).Msg("Get client's item count failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -259,7 +259,6 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // Map to save commit data type ID & mtime typeMtimeMap := make(map[int]int64) for i, v := range commitMsg.Entries { - var conflict, deleted bool entryRsp := &sync_pb.CommitResponse_EntryResponse{} commitRsp.Entryresponse[i] = entryRsp @@ -285,7 +284,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c *entityToCommit.Version = *entityToCommit.Mtime if *entityToCommit.DataType == datastore.HistoryTypeID { // Check if item exists using client_unique_tag - isUpdateOp, err = db.HasItem(clientID, *entityToCommit.ClientDefinedUniqueTag) + isUpdateOp, err = db.HasItem(ctx, clientID, *entityToCommit.ClientDefinedUniqueTag) if err != nil { log.Error().Err(err).Msg("Insert history sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -295,6 +294,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c } } + var interimErr error if !isUpdateOp { // Create if currentNormalItemCount+currentHistoryItemCount+newNormalCount+newHistoryCount >= maxClientObjectQuota+boostedQuotaAddition { rspType := sync_pb.CommitResponse_OVER_QUOTA @@ -307,7 +307,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // Insert all non-history items. For history items, ignore any items above history quoto // and lie to the client about the objects being synced instead of returning OVER_QUOTA // so the client can continue to sync other entities. - conflict, err = db.InsertSyncEntity(entityToCommit) + conflict, err := db.InsertSyncEntity(ctx, entityToCommit) if err != nil { log.Error().Err(err).Msg("Insert sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -326,13 +326,13 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c } if isHistoryRelatedItem { - newHistoryCount, err = cache.IncrementInterimCount(context.Background(), clientID, historyCountTypeStr, false) + newHistoryCount, interimErr = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, false) } else { - newNormalCount, err = cache.IncrementInterimCount(context.Background(), clientID, normalCountTypeStr, false) + newNormalCount, interimErr = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, false) } } } else { // Update - conflict, deleted, err = db.UpdateSyncEntity(entityToCommit, oldVersion) + conflict, deleted, err := db.UpdateSyncEntity(ctx, entityToCommit, oldVersion) if err != nil { log.Error().Err(err).Msg("Update sync entity failed") rspType := sync_pb.CommitResponse_TRANSIENT_ERROR @@ -347,16 +347,16 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c } if deleted { if isHistoryRelatedItem { - newHistoryCount, err = cache.IncrementInterimCount(context.Background(), clientID, historyCountTypeStr, true) + newHistoryCount, interimErr = cache.IncrementInterimCount(ctx, clientID, historyCountTypeStr, true) } else { - newNormalCount, err = cache.IncrementInterimCount(context.Background(), clientID, normalCountTypeStr, true) + newNormalCount, interimErr = cache.IncrementInterimCount(ctx, clientID, normalCountTypeStr, true) } } } - if err != nil { - log.Error().Err(err).Msg("Interim count update failed") + if interimErr != nil { + log.Error().Err(interimErr).Msg("Interim count update failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR - return &errCode, fmt.Errorf("interim count update failed: %w", err) + return &errCode, fmt.Errorf("interim count update failed: %w", interimErr) } typeMtimeMap[*entityToCommit.DataType] = *entityToCommit.Mtime @@ -368,7 +368,7 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c entryRsp.Mtime = entityToCommit.Mtime } - newNormalCount, newHistoryCount, err = getInterimItemCounts(cache, clientID, true) + newNormalCount, newHistoryCount, err = getInterimItemCounts(ctx, cache, clientID, true) if err != nil { log.Error().Err(err).Msg("Get interim item counts failed") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -377,10 +377,10 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // Save (clientID#dataType, mtime) into cache after writing into DB. for dataType, mtime := range typeMtimeMap { - cache.SetTypeMtime(context.Background(), clientID, dataType, mtime) + cache.SetTypeMtime(ctx, clientID, dataType, mtime) } - err = db.UpdateClientItemCount(itemCounts, newNormalCount, newHistoryCount) + err = db.UpdateClientItemCount(ctx, itemCounts, newNormalCount, newHistoryCount) if err != nil { // We only impose a soft quota limit on the item count for each client, so // we only log the error without further actions here. The reason of this @@ -397,18 +397,18 @@ func handleCommitRequest(cache *cache.Cache, commitMsg *sync_pb.CommitMessage, c // handleClearServerDataRequest handles clearing user data from the datastore and cache // and fills the response -func handleClearServerDataRequest(cache *cache.Cache, db datastore.Datastore, _ *sync_pb.ClearServerDataMessage, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { +func handleClearServerDataRequest(ctx context.Context, cache *cache.Cache, db datastore.Datastore, _ *sync_pb.ClearServerDataMessage, clientID string) (*sync_pb.SyncEnums_ErrorType, error) { errCode := sync_pb.SyncEnums_SUCCESS var err error - err = db.DisableSyncChain(clientID) + err = db.DisableSyncChain(ctx, clientID) if err != nil { log.Error().Err(err).Msg("Failed to disable sync chain") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR return &errCode, err } - syncEntities, err := db.ClearServerData(clientID) + syncEntities, err := db.ClearServerData(ctx, clientID) if err != nil { errCode = sync_pb.SyncEnums_TRANSIENT_ERROR return &errCode, err @@ -422,7 +422,7 @@ func handleClearServerDataRequest(cache *cache.Cache, db datastore.Datastore, _ } if len(typeMtimeCacheKeys) > 0 { - err = cache.Del(context.Background(), typeMtimeCacheKeys...) + err = cache.Del(ctx, typeMtimeCacheKeys...) if err != nil { log.Error().Err(err).Msg("Failed to clear cache") errCode = sync_pb.SyncEnums_TRANSIENT_ERROR @@ -435,7 +435,7 @@ func handleClearServerDataRequest(cache *cache.Cache, db datastore.Datastore, _ // HandleClientToServerMessage handles the protobuf ClientToServerMessage and // fills the protobuf ClientToServerResponse. -func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerMessage, pbRsp *sync_pb.ClientToServerResponse, db datastore.Datastore, clientID string) error { +func HandleClientToServerMessage(ctx context.Context, cache *cache.Cache, pb *sync_pb.ClientToServerMessage, pbRsp *sync_pb.ClientToServerResponse, db datastore.Datastore, clientID string) error { // Create ClientToServerResponse and fill general fields for both GU and // Commit. pbRsp.StoreBirthday = aws.String(storeBirthday) @@ -449,7 +449,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM } else if *pb.MessageContents == sync_pb.ClientToServerMessage_GET_UPDATES { guRsp := &sync_pb.GetUpdatesResponse{} pbRsp.GetUpdates = guRsp - pbRsp.ErrorCode, err = handleGetUpdatesRequest(cache, pb.GetUpdates, guRsp, db, clientID) + pbRsp.ErrorCode, err = handleGetUpdatesRequest(ctx, cache, pb.GetUpdates, guRsp, db, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) @@ -463,7 +463,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM } else if *pb.MessageContents == sync_pb.ClientToServerMessage_COMMIT { commitRsp := &sync_pb.CommitResponse{} pbRsp.Commit = commitRsp - pbRsp.ErrorCode, err = handleCommitRequest(cache, pb.Commit, commitRsp, db, clientID) + pbRsp.ErrorCode, err = handleCommitRequest(context.TODO(), cache, pb.Commit, commitRsp, db, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) @@ -477,7 +477,7 @@ func HandleClientToServerMessage(cache *cache.Cache, pb *sync_pb.ClientToServerM } else if *pb.MessageContents == sync_pb.ClientToServerMessage_CLEAR_SERVER_DATA { csdRsp := &sync_pb.ClearServerDataResponse{} pbRsp.ClearServerData = csdRsp - pbRsp.ErrorCode, err = handleClearServerDataRequest(cache, db, pb.ClearServerData, clientID) + pbRsp.ErrorCode, err = handleClearServerDataRequest(context.Background(), cache, db, pb.ClearServerData, clientID) if err != nil { if pbRsp.ErrorCode != nil { pbRsp.ErrorMessage = aws.String(err.Error()) diff --git a/command/command_test.go b/command/command_test.go index e646d5c5..1f8bb3c9 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -233,7 +233,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { // Commit and check response. suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 2) @@ -253,7 +253,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { marker, sync_pb.SyncEnums_GU_TRIGGER, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -277,7 +277,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) @@ -299,7 +299,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { marker, sync_pb.SyncEnums_GU_TRIGGER, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -324,7 +324,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { msg = getClientToServerCommitMsg(entries) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 2) @@ -340,7 +340,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_Basic() { marker, sync_pb.SyncEnums_GU_TRIGGER, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -357,7 +357,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_NewClient() { rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) @@ -409,7 +409,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceed rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, testCase.clientID), "HandleClientToServerMessage should succeed for device %d", i) suite.Equal(sync_pb.SyncEnums_SUCCESS, *rsp.ErrorCode, "device %d should succeed", i) suite.NotNil(rsp.GetUpdates, "device %d should have GetUpdates response", i) @@ -419,7 +419,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceed commitMsg := getClientToServerCommitMsg([]*sync_pb.SyncEntity{deviceEntry}) commitRsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, commitMsg, commitRsp, suite.dynamo, testCase.clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, commitMsg, commitRsp, suite.dynamo, testCase.clientID), "Commit device info should succeed for device %d", i) suite.Equal(sync_pb.SyncEnums_SUCCESS, *commitRsp.ErrorCode, "Commit device info should succeed for device %d", i) } @@ -427,7 +427,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceed // should get THROTTLED error when device limit is exceeded rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, testCase.clientID), "HandleClientToServerMessage should succeed") suite.Equal(sync_pb.SyncEnums_THROTTLED, *rsp.ErrorCode, "errorCode should be THROTTLED") suite.Require().NotNil(rsp.ErrorMessage, "error message should be present") @@ -448,7 +448,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_GUBatchSize() { // Commit and check response. suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 4) @@ -472,7 +472,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 2) @@ -497,7 +497,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 4) @@ -518,7 +518,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 2) @@ -535,7 +535,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 2) @@ -555,7 +555,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_QuotaLimit() { rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 4) @@ -577,7 +577,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo msg := getClientToServerCommitMsg([]*sync_pb.SyncEntity{child0}) rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 1) @@ -606,7 +606,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 6) @@ -621,7 +621,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_ReplaceParentIDTo marker, sync_pb.SyncEnums_GU_TRIGGER, true, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Len(rsp.GetUpdates.Entries, 6) @@ -651,7 +651,7 @@ func insertSyncEntitiesWithoutUpdateCache( for _, entry := range entries { dbEntry, err := datastore.CreateDBSyncEntity(entry, nil, clientID) suite.Require().NoError(err, "Create db entity from pb entity should succeed") - _, err = suite.dynamo.InsertSyncEntity(dbEntry) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), dbEntry) suite.Require().NoError(err, "Insert sync entity should succeed") val, err := suite.cache.Get(context.Background(), clientID+"#"+strconv.Itoa(*dbEntry.DataType), false) @@ -675,7 +675,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 3) @@ -716,7 +716,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba marker, sync_pb.SyncEnums_PERIODIC, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Empty(rsp.GetUpdates.Entries) @@ -739,7 +739,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 1) @@ -757,7 +757,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ba marker, sync_pb.SyncEnums_PERIODIC, false, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Len(rsp.GetUpdates.Entries, 2) @@ -778,7 +778,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Sk rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 1) @@ -805,7 +805,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Sk marker, sync_pb.SyncEnums_GU_TRIGGER, true, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Len(rsp.GetUpdates.Entries, 1) @@ -824,7 +824,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch rsp := &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, true) suite.Len(rsp.Commit.Entryresponse, 2) @@ -847,7 +847,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch marker, sync_pb.SyncEnums_PERIODIC, true, &clientBatch) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Len(rsp.GetUpdates.Entries, 2) @@ -864,7 +864,7 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_TypeMtimeCache_Ch marker, sync_pb.SyncEnums_PERIODIC, true, nil) rsp = &sync_pb.ClientToServerResponse{} suite.Require().NoError( - command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, clientID), + command.HandleClientToServerMessage(context.Background(), suite.cache, msg, rsp, suite.dynamo, clientID), "HandleClientToServerMessage should succeed") assertCommonResponse(suite, rsp, false) suite.Require().Len(rsp.GetUpdates.Entries, 1) diff --git a/command/server_defined_unique_entity.go b/command/server_defined_unique_entity.go index 8c81f310..60927924 100644 --- a/command/server_defined_unique_entity.go +++ b/command/server_defined_unique_entity.go @@ -1,6 +1,7 @@ package command import ( + "context" "fmt" "time" @@ -42,11 +43,11 @@ func createServerDefinedUniqueEntity(name string, serverDefinedTag string, clien // InsertServerDefinedUniqueEntities inserts the server defined unique tag // entities if it is not in the DB yet for a specific client. -func InsertServerDefinedUniqueEntities(db datastore.Datastore, clientID string) error { +func InsertServerDefinedUniqueEntities(ctx context.Context, db datastore.Datastore, clientID string) error { var entities []*datastore.SyncEntity // Check if they're existed already for this client. // If yes, just return directly. - ready, err := db.HasServerDefinedUniqueTag(clientID, nigoriTag) + ready, err := db.HasServerDefinedUniqueTag(ctx, clientID, nigoriTag) if err != nil { return fmt.Errorf("error checking if entity with a server tag existed: %w", err) } @@ -90,7 +91,7 @@ func InsertServerDefinedUniqueEntities(db datastore.Datastore, clientID string) } // Start a transaction to insert all server defined unique entities - err = db.InsertSyncEntitiesWithServerTags(entities) + err = db.InsertSyncEntitiesWithServerTags(ctx, entities) if err != nil { return fmt.Errorf("error inserting entities with server tags: %w", err) } diff --git a/command/server_defined_unique_entity_test.go b/command/server_defined_unique_entity_test.go index ea5119c3..21f7b977 100644 --- a/command/server_defined_unique_entity_test.go +++ b/command/server_defined_unique_entity_test.go @@ -1,6 +1,7 @@ package command_test import ( + "context" "sort" "testing" @@ -47,10 +48,10 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TearDownTest() { func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEntities() { suite.Require().NoError( - command.InsertServerDefinedUniqueEntities(suite.dynamo, "client1"), + command.InsertServerDefinedUniqueEntities(context.Background(), suite.dynamo, "client1"), "InsertServerDefinedUniqueEntities should succeed") suite.Require().NoError( - command.InsertServerDefinedUniqueEntities(suite.dynamo, "client1"), + command.InsertServerDefinedUniqueEntities(context.Background(), suite.dynamo, "client1"), "InsertServerDefinedUniqueEntities again for a same client should succeed") expectedSyncAttrsMap := map[string]*SyncAttrs{ @@ -157,7 +158,7 @@ func (suite *ServerDefinedUniqueEntityTestSuite) TestInsertServerDefinedUniqueEn suite.Empty(expectedSyncAttrsMap) suite.Require().NoError( - command.InsertServerDefinedUniqueEntities(suite.dynamo, "client2"), + command.InsertServerDefinedUniqueEntities(context.Background(), suite.dynamo, "client2"), "InsertServerDefinedUniqueEntities should succeed for another client") } diff --git a/controller/controller.go b/controller/controller.go index 382d17e7..9c7b8365 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -72,7 +72,7 @@ func Command(cache *cache.Cache, db datastore.Datastore) http.HandlerFunc { } pbRsp := &sync_pb.ClientToServerResponse{} - err = command.HandleClientToServerMessage(cache, pb, pbRsp, db, clientID) + err = command.HandleClientToServerMessage(ctx, cache, pb, pbRsp, db, clientID) if err != nil { log.Error().Err(err).Msg("Handle command message failed") http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/datastore/datastore.go b/datastore/datastore.go index 5e3e5451..217642f9 100644 --- a/datastore/datastore.go +++ b/datastore/datastore.go @@ -1,30 +1,32 @@ package datastore +import "context" + // Datastore abstracts over the underlying datastore. type Datastore interface { // Insert a new sync entity. - InsertSyncEntity(entity *SyncEntity) (bool, error) + InsertSyncEntity(ctx context.Context, entity *SyncEntity) (bool, error) // Insert a series of sync entities in a write transaction. - InsertSyncEntitiesWithServerTags(entities []*SyncEntity) error + InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*SyncEntity) error // Update an existing sync entity. - UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) + UpdateSyncEntity(ctx context.Context, entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) // Get updates for a specific type which are modified after the time of // client token for a given client. Besides the array of sync entities, a // boolean value indicating whether there are more updates to query in the // next batch is returned. - GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) + GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) // Check if a server-defined unique tag is in the datastore. - HasServerDefinedUniqueTag(clientID string, tag string) (bool, error) + HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (bool, error) // Get the count of sync items for a client. - GetClientItemCount(clientID string) (*ClientItemCounts, error) + GetClientItemCount(ctx context.Context, clientID string) (*ClientItemCounts, error) // Update the count of sync items for a client. - UpdateClientItemCount(counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error + UpdateClientItemCount(ctx context.Context, counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error // ClearServerData deletes all items for a given clientID - ClearServerData(clientID string) ([]SyncEntity, error) + ClearServerData(ctx context.Context, clientID string) ([]SyncEntity, error) // DisableSyncChain marks a chain as disabled so no further updates or commits can happen - DisableSyncChain(clientID string) error + DisableSyncChain(ctx context.Context, clientID string) error // IsSyncChainDisabled checks whether a given sync chain is deleted - IsSyncChainDisabled(clientID string) (bool, error) + IsSyncChainDisabled(ctx context.Context, clientID string) (bool, error) // Checks if sync item exists for a client - HasItem(clientID string, ID string) (bool, error) + HasItem(ctx context.Context, clientID string, ID string) (bool, error) } diff --git a/datastore/datastoretest/mock_datastore.go b/datastore/datastoretest/mock_datastore.go index 5cda4da5..150d3b0c 100644 --- a/datastore/datastoretest/mock_datastore.go +++ b/datastore/datastoretest/mock_datastore.go @@ -1,6 +1,8 @@ package datastoretest import ( + "context" + "github.com/stretchr/testify/mock" "github.com/brave/go-sync/datastore" @@ -12,66 +14,66 @@ type MockDatastore struct { } // InsertSyncEntity mocks calls to InsertSyncEntity -func (m *MockDatastore) InsertSyncEntity(entity *datastore.SyncEntity) (bool, error) { - args := m.Called(entity) +func (m *MockDatastore) InsertSyncEntity(ctx context.Context, entity *datastore.SyncEntity) (bool, error) { + args := m.Called(ctx, entity) return args.Bool(0), args.Error(1) } // InsertSyncEntitiesWithServerTags mocks calls to InsertSyncEntitiesWithServerTags -func (m *MockDatastore) InsertSyncEntitiesWithServerTags(entities []*datastore.SyncEntity) error { - args := m.Called(entities) +func (m *MockDatastore) InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*datastore.SyncEntity) error { + args := m.Called(ctx, entities) return args.Error(0) } // UpdateSyncEntity mocks calls to UpdateSyncEntity -func (m *MockDatastore) UpdateSyncEntity(entity *datastore.SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { - args := m.Called(entity, oldVersion) +func (m *MockDatastore) UpdateSyncEntity(ctx context.Context, entity *datastore.SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { + args := m.Called(ctx, entity, oldVersion) return args.Bool(0), args.Bool(1), args.Error(2) } // GetUpdatesForType mocks calls to GetUpdatesForType -func (m *MockDatastore) GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []datastore.SyncEntity, error) { - args := m.Called(dataType, clientToken, fetchFolders, clientID, maxSize) +func (m *MockDatastore) GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []datastore.SyncEntity, error) { + args := m.Called(ctx, dataType, clientToken, fetchFolders, clientID, maxSize) return args.Bool(0), args.Get(1).([]datastore.SyncEntity), args.Error(2) } // HasServerDefinedUniqueTag mocks calls to HasServerDefinedUniqueTag -func (m *MockDatastore) HasServerDefinedUniqueTag(clientID string, tag string) (bool, error) { - args := m.Called(clientID, tag) +func (m *MockDatastore) HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (bool, error) { + args := m.Called(ctx, clientID, tag) return args.Bool(0), args.Error(1) } -func (m *MockDatastore) HasItem(clientID string, ID string) (bool, error) { - args := m.Called(clientID, ID) +func (m *MockDatastore) HasItem(ctx context.Context, clientID string, ID string) (bool, error) { + args := m.Called(ctx, clientID, ID) return args.Bool(0), args.Error(1) } // GetClientItemCount mocks calls to GetClientItemCount -func (m *MockDatastore) GetClientItemCount(clientID string) (*datastore.ClientItemCounts, error) { - args := m.Called(clientID) +func (m *MockDatastore) GetClientItemCount(ctx context.Context, clientID string) (*datastore.ClientItemCounts, error) { + args := m.Called(ctx, clientID) return &datastore.ClientItemCounts{ClientID: clientID, ID: clientID}, args.Error(1) } // UpdateClientItemCount mocks calls to UpdateClientItemCount -func (m *MockDatastore) UpdateClientItemCount(counts *datastore.ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { - args := m.Called(counts, newNormalItemCount, newHistoryItemCount) +func (m *MockDatastore) UpdateClientItemCount(ctx context.Context, counts *datastore.ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { + args := m.Called(ctx, counts, newNormalItemCount, newHistoryItemCount) return args.Error(0) } // ClearServerData mocks calls to ClearServerData -func (m *MockDatastore) ClearServerData(clientID string) ([]datastore.SyncEntity, error) { - args := m.Called(clientID) +func (m *MockDatastore) ClearServerData(ctx context.Context, clientID string) ([]datastore.SyncEntity, error) { + args := m.Called(ctx, clientID) return args.Get(0).([]datastore.SyncEntity), args.Error(1) } // DisableSyncChain mocks calls to DisableSyncChain -func (m *MockDatastore) DisableSyncChain(clientID string) error { - args := m.Called(clientID) +func (m *MockDatastore) DisableSyncChain(ctx context.Context, clientID string) error { + args := m.Called(ctx, clientID) return args.Error(0) } // IsSyncChainDisabled mocks calls to IsSyncChainDisabled -func (m *MockDatastore) IsSyncChainDisabled(clientID string) (bool, error) { - args := m.Called(clientID) +func (m *MockDatastore) IsSyncChainDisabled(ctx context.Context, clientID string) (bool, error) { + args := m.Called(ctx, clientID) return args.Bool(0), args.Error(1) } diff --git a/datastore/instrumented_datastore.go b/datastore/instrumented_datastore.go index e2e82732..2abecd13 100644 --- a/datastore/instrumented_datastore.go +++ b/datastore/instrumented_datastore.go @@ -1,12 +1,13 @@ -package datastore +// Code generated by gowrap. DO NOT EDIT. +// template: ../.prom-gowrap.tmpl +// gowrap: http://github.com/hexdigest/gowrap -// DO NOT EDIT! -// This code is generated with http://github.com/hexdigest/gowrap tool -// using ../.prom-gowrap.tmpl template +package datastore -//go:generate gowrap gen -p github.com/brave/go-sync/datastore -i Datastore -t ../.prom-gowrap.tmpl -o instrumented_datastore.go +//go:generate gowrap gen -p github.com/brave/go-sync/datastore -i Datastore -t ../.prom-gowrap.tmpl -o instrumented_datastore.go -l "" import ( + "context" "time" "github.com/prometheus/client_golang/prometheus" @@ -38,7 +39,7 @@ func NewDatastoreWithPrometheus(base Datastore, instanceName string) DatastoreWi } // ClearServerData implements Datastore -func (_d DatastoreWithPrometheus) ClearServerData(clientID string) (sa1 []SyncEntity, err error) { +func (_d DatastoreWithPrometheus) ClearServerData(ctx context.Context, clientID string) (sa1 []SyncEntity, err error) { _since := time.Now() defer func() { result := "ok" @@ -48,11 +49,11 @@ func (_d DatastoreWithPrometheus) ClearServerData(clientID string) (sa1 []SyncEn datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "ClearServerData", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.ClearServerData(clientID) + return _d.base.ClearServerData(ctx, clientID) } // DisableSyncChain implements Datastore -func (_d DatastoreWithPrometheus) DisableSyncChain(clientID string) (err error) { +func (_d DatastoreWithPrometheus) DisableSyncChain(ctx context.Context, clientID string) (err error) { _since := time.Now() defer func() { result := "ok" @@ -62,11 +63,11 @@ func (_d DatastoreWithPrometheus) DisableSyncChain(clientID string) (err error) datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "DisableSyncChain", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.DisableSyncChain(clientID) + return _d.base.DisableSyncChain(ctx, clientID) } // GetClientItemCount implements Datastore -func (_d DatastoreWithPrometheus) GetClientItemCount(clientID string) (counts *ClientItemCounts, err error) { +func (_d DatastoreWithPrometheus) GetClientItemCount(ctx context.Context, clientID string) (cp1 *ClientItemCounts, err error) { _since := time.Now() defer func() { result := "ok" @@ -76,11 +77,11 @@ func (_d DatastoreWithPrometheus) GetClientItemCount(clientID string) (counts *C datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "GetClientItemCount", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.GetClientItemCount(clientID) + return _d.base.GetClientItemCount(ctx, clientID) } // GetUpdatesForType implements Datastore -func (_d DatastoreWithPrometheus) GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (b1 bool, sa1 []SyncEntity, err error) { +func (_d DatastoreWithPrometheus) GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (b1 bool, sa1 []SyncEntity, err error) { _since := time.Now() defer func() { result := "ok" @@ -90,11 +91,11 @@ func (_d DatastoreWithPrometheus) GetUpdatesForType(dataType int, clientToken in datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "GetUpdatesForType", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.GetUpdatesForType(dataType, clientToken, fetchFolders, clientID, maxSize) + return _d.base.GetUpdatesForType(ctx, dataType, clientToken, fetchFolders, clientID, maxSize) } // HasItem implements Datastore -func (_d DatastoreWithPrometheus) HasItem(clientID string, ID string) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) HasItem(ctx context.Context, clientID string, ID string) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -104,11 +105,11 @@ func (_d DatastoreWithPrometheus) HasItem(clientID string, ID string) (b1 bool, datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "HasItem", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.HasItem(clientID, ID) + return _d.base.HasItem(ctx, clientID, ID) } // HasServerDefinedUniqueTag implements Datastore -func (_d DatastoreWithPrometheus) HasServerDefinedUniqueTag(clientID string, tag string) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -118,11 +119,11 @@ func (_d DatastoreWithPrometheus) HasServerDefinedUniqueTag(clientID string, tag datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "HasServerDefinedUniqueTag", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.HasServerDefinedUniqueTag(clientID, tag) + return _d.base.HasServerDefinedUniqueTag(ctx, clientID, tag) } // InsertSyncEntitiesWithServerTags implements Datastore -func (_d DatastoreWithPrometheus) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) (err error) { +func (_d DatastoreWithPrometheus) InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*SyncEntity) (err error) { _since := time.Now() defer func() { result := "ok" @@ -132,11 +133,11 @@ func (_d DatastoreWithPrometheus) InsertSyncEntitiesWithServerTags(entities []*S datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "InsertSyncEntitiesWithServerTags", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.InsertSyncEntitiesWithServerTags(entities) + return _d.base.InsertSyncEntitiesWithServerTags(ctx, entities) } // InsertSyncEntity implements Datastore -func (_d DatastoreWithPrometheus) InsertSyncEntity(entity *SyncEntity) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) InsertSyncEntity(ctx context.Context, entity *SyncEntity) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -146,11 +147,11 @@ func (_d DatastoreWithPrometheus) InsertSyncEntity(entity *SyncEntity) (b1 bool, datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "InsertSyncEntity", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.InsertSyncEntity(entity) + return _d.base.InsertSyncEntity(ctx, entity) } // IsSyncChainDisabled implements Datastore -func (_d DatastoreWithPrometheus) IsSyncChainDisabled(clientID string) (b1 bool, err error) { +func (_d DatastoreWithPrometheus) IsSyncChainDisabled(ctx context.Context, clientID string) (b1 bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -160,11 +161,11 @@ func (_d DatastoreWithPrometheus) IsSyncChainDisabled(clientID string) (b1 bool, datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "IsSyncChainDisabled", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.IsSyncChainDisabled(clientID) + return _d.base.IsSyncChainDisabled(ctx, clientID) } // UpdateClientItemCount implements Datastore -func (_d DatastoreWithPrometheus) UpdateClientItemCount(counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) (err error) { +func (_d DatastoreWithPrometheus) UpdateClientItemCount(ctx context.Context, counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) (err error) { _since := time.Now() defer func() { result := "ok" @@ -174,11 +175,11 @@ func (_d DatastoreWithPrometheus) UpdateClientItemCount(counts *ClientItemCounts datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "UpdateClientItemCount", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.UpdateClientItemCount(counts, newNormalItemCount, newHistoryItemCount) + return _d.base.UpdateClientItemCount(ctx, counts, newNormalItemCount, newHistoryItemCount) } // UpdateSyncEntity implements Datastore -func (_d DatastoreWithPrometheus) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { +func (_d DatastoreWithPrometheus) UpdateSyncEntity(ctx context.Context, entity *SyncEntity, oldVersion int64) (conflict bool, deleted bool, err error) { _since := time.Now() defer func() { result := "ok" @@ -188,5 +189,5 @@ func (_d DatastoreWithPrometheus) UpdateSyncEntity(entity *SyncEntity, oldVersio datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "UpdateSyncEntity", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.UpdateSyncEntity(entity, oldVersion) + return _d.base.UpdateSyncEntity(ctx, entity, oldVersion) } diff --git a/datastore/item_count.go b/datastore/item_count.go index b24d3f24..a5b2f671 100644 --- a/datastore/item_count.go +++ b/datastore/item_count.go @@ -51,7 +51,7 @@ func (counts *ClientItemCounts) SumHistoryCounts() int { counts.HistoryItemCountPeriod4 } -func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCounts) error { +func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(ctx context.Context, counts *ClientItemCounts) error { now := time.Now().Unix() if counts.Version < CurrentCountVersion { if counts.ItemCount > 0 { @@ -75,7 +75,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou TableName: aws.String(Table), Select: types.SelectCount, } - out, err := dynamo.Query(context.TODO(), historyCountInput) + out, err := dynamo.Query(ctx, historyCountInput) if err != nil { return fmt.Errorf("error querying history item count: %w", err) } @@ -101,7 +101,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou TableName: aws.String(Table), Select: types.SelectCount, } - out, err = dynamo.Query(context.TODO(), normalCountInput) + out, err = dynamo.Query(ctx, normalCountInput) if err != nil { return fmt.Errorf("error querying history item count: %w", err) } @@ -129,7 +129,7 @@ func (dynamo *Dynamo) initRealCountsAndUpdateHistoryCounts(counts *ClientItemCou // GetClientItemCount returns the count of non-deleted sync items stored for // a given client. -func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, error) { +func (dynamo *Dynamo) GetClientItemCount(ctx context.Context, clientID string) (*ClientItemCounts, error) { primaryKey := PrimaryKey{ClientID: clientID, ID: clientID} key, err := attributevalue.MarshalMap(primaryKey) if err != nil { @@ -141,7 +141,7 @@ func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, er TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return nil, fmt.Errorf("error getting an item-count item: %w", err) } @@ -157,7 +157,7 @@ func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, er clientItemCounts.ID = clientID } - if err = dynamo.initRealCountsAndUpdateHistoryCounts(clientItemCounts); err != nil { + if err = dynamo.initRealCountsAndUpdateHistoryCounts(ctx, clientItemCounts); err != nil { return nil, err } @@ -166,7 +166,7 @@ func (dynamo *Dynamo) GetClientItemCount(clientID string) (*ClientItemCounts, er // UpdateClientItemCount updates the count of non-deleted sync items for a // given client stored in the dynamoDB. -func (dynamo *Dynamo) UpdateClientItemCount(counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { +func (dynamo *Dynamo) UpdateClientItemCount(ctx context.Context, counts *ClientItemCounts, newNormalItemCount int, newHistoryItemCount int) error { counts.HistoryItemCountPeriod4 += newHistoryItemCount counts.ItemCount += newNormalItemCount @@ -180,7 +180,7 @@ func (dynamo *Dynamo) UpdateClientItemCount(counts *ClientItemCounts, newNormalI TableName: aws.String(Table), } - _, err = dynamo.PutItem(context.TODO(), input) + _, err = dynamo.PutItem(ctx, input) if err != nil { return fmt.Errorf("error updating item-count item in dynamoDB: %w", err) } diff --git a/datastore/item_count_test.go b/datastore/item_count_test.go index 54f307c2..ada6f745 100644 --- a/datastore/item_count_test.go +++ b/datastore/item_count_test.go @@ -1,6 +1,7 @@ package datastore_test import ( + "context" "sort" "testing" @@ -42,17 +43,17 @@ func (suite *ItemCountTestSuite) TestGetClientItemCount() { for _, item := range items { existing := datastore.ClientItemCounts{ClientID: item.ClientID, ID: item.ID, Version: datastore.CurrentCountVersion} suite.Require().NoError( - suite.dynamo.UpdateClientItemCount(&existing, item.ItemCount, 0)) + suite.dynamo.UpdateClientItemCount(context.Background(), &existing, item.ItemCount, 0)) } for _, item := range items { - count, err := suite.dynamo.GetClientItemCount(item.ClientID) + count, err := suite.dynamo.GetClientItemCount(context.Background(), item.ClientID) suite.Require().NoError(err, "GetClientItemCount should succeed") suite.Equal(count.ItemCount, item.ItemCount, "ItemCount should match") } // Non-exist client item count should succeed with count = 0. - count, err := suite.dynamo.GetClientItemCount("client3") + count, err := suite.dynamo.GetClientItemCount(context.Background(), "client3") suite.Require().NoError(err, "Get non-exist ClientItemCount should succeed") suite.Equal(0, count.ItemCount) } @@ -69,10 +70,10 @@ func (suite *ItemCountTestSuite) TestUpdateClientItemCount() { } for _, item := range items { - count, err := suite.dynamo.GetClientItemCount(item.ClientID) + count, err := suite.dynamo.GetClientItemCount(context.Background(), item.ClientID) suite.Require().NoError(err) suite.Require().NoError( - suite.dynamo.UpdateClientItemCount(count, item.ItemCount, 0)) + suite.dynamo.UpdateClientItemCount(context.Background(), count, item.ItemCount, 0)) } clientCountItems, err := datastoretest.ScanClientItemCounts(suite.dynamo) diff --git a/datastore/sync_entity.go b/datastore/sync_entity.go index e87d8926..144ea3ec 100644 --- a/datastore/sync_entity.go +++ b/datastore/sync_entity.go @@ -159,7 +159,7 @@ func NewServerClientUniqueTagItemQuery(clientID string, tag string, isServer boo // write a sync item along with a tag item to ensure the uniqueness of the // client tag. Otherwise, only a sync item is written into DB without using // transactions. -func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { +func (dynamo *Dynamo) InsertSyncEntity(ctx context.Context, entity *SyncEntity) (bool, error) { // Create a condition for inserting new items only. cond := expression.AttributeNotExists(expression.Name(pk)) expr, err := expression.NewBuilder().WithCondition(cond).Build() @@ -204,7 +204,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { items = append(items, tagItem) items = append(items, syncItem) - _, err = dynamo.TransactWriteItems(context.TODO(), + _, err = dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { // Return conflict if insert condition failed. @@ -234,7 +234,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { ConditionExpression: expr.Condition(), TableName: aws.String(Table), } - _, err = dynamo.PutItem(context.TODO(), input) + _, err = dynamo.PutItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling PutItem to insert sync item: %w", err) } @@ -243,7 +243,7 @@ func (dynamo *Dynamo) InsertSyncEntity(entity *SyncEntity) (bool, error) { // HasServerDefinedUniqueTag check the tag item to see if there is already a // tag item exists with the tag value for a specific client. -func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bool, error) { +func (dynamo *Dynamo) HasServerDefinedUniqueTag(ctx context.Context, clientID string, tag string) (bool, error) { tagItem := NewServerClientUniqueTagItemQuery(clientID, tag, true) key, err := attributevalue.MarshalMap(tagItem) if err != nil { @@ -256,7 +256,7 @@ func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bo TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if server tag existed: %w", err) } @@ -264,7 +264,7 @@ func (dynamo *Dynamo) HasServerDefinedUniqueTag(clientID string, tag string) (bo return out.Item != nil, nil } -func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { +func (dynamo *Dynamo) HasItem(ctx context.Context, clientID string, ID string) (bool, error) { primaryKey := PrimaryKey{ClientID: clientID, ID: ID} key, err := attributevalue.MarshalMap(primaryKey) @@ -278,7 +278,7 @@ func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if item existed: %w", err) } @@ -290,7 +290,7 @@ func (dynamo *Dynamo) HasItem(clientID string, ID string) (bool, error) { // server-defined unique tags. To ensure the uniqueness, for each sync entity, // we will write a tag item and a sync item. Items for all the entities in the // array would be written into DB in one transaction. -func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) error { +func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(ctx context.Context, entities []*SyncEntity) error { items := make([]types.TransactWriteItem, 0, len(entities)*2) for _, entity := range entities { // Create a condition for inserting new items only. @@ -335,7 +335,7 @@ func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) e items = append(items, syncItem) } - _, err := dynamo.TransactWriteItems(context.TODO(), + _, err := dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { return fmt.Errorf("error writing sync entities with server tags in a transaction: %w", err) @@ -344,7 +344,7 @@ func (dynamo *Dynamo) InsertSyncEntitiesWithServerTags(entities []*SyncEntity) e } // DisableSyncChain marks a chain as disabled so no further updates or commits can happen -func (dynamo *Dynamo) DisableSyncChain(clientID string) error { +func (dynamo *Dynamo) DisableSyncChain(ctx context.Context, clientID string) error { now := aws.Int64(time.Now().UnixMilli()) disabledMarker := DisabledMarkerItem{ ClientID: clientID, @@ -364,7 +364,7 @@ func (dynamo *Dynamo) DisableSyncChain(clientID string) error { TableName: aws.String(Table), } - _, err = dynamo.PutItem(context.TODO(), markerInput) + _, err = dynamo.PutItem(ctx, markerInput) if err != nil { return fmt.Errorf("error calling PutItem to insert sync item: %w", err) } @@ -373,7 +373,7 @@ func (dynamo *Dynamo) DisableSyncChain(clientID string) error { } // ClearServerData deletes all items for a given clientID -func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { +func (dynamo *Dynamo) ClearServerData(ctx context.Context, clientID string) ([]SyncEntity, error) { syncEntities := []SyncEntity{} pkb := expression.Key(pk) pkv := expression.Value(clientID) @@ -392,7 +392,7 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { TableName: aws.String(Table), } - out, err := dynamo.Query(context.TODO(), input) + out, err := dynamo.Query(ctx, input) if err != nil { return syncEntities, fmt.Errorf("error doing query to get updates: %w", err) } @@ -459,7 +459,7 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { } } - _, err = dynamo.TransactWriteItems(context.TODO(), &dynamodb.TransactWriteItemsInput{TransactItems: items}) + _, err = dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { return syncEntities, fmt.Errorf("error deleting sync entities for client %s: %w", clientID, err) } @@ -469,7 +469,7 @@ func (dynamo *Dynamo) ClearServerData(clientID string) ([]SyncEntity, error) { } // IsSyncChainDisabled checks whether a given sync chain has been deleted -func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { +func (dynamo *Dynamo) IsSyncChainDisabled(ctx context.Context, clientID string) (bool, error) { key, err := attributevalue.MarshalMap(DisabledMarkerItemQuery{ ClientID: clientID, ID: disabledChainID, @@ -483,7 +483,7 @@ func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { TableName: aws.String(Table), } - out, err := dynamo.GetItem(context.TODO(), input) + out, err := dynamo.GetItem(ctx, input) if err != nil { return false, fmt.Errorf("error calling GetItem to check if sync chain disabled: %w", err) } @@ -492,7 +492,7 @@ func (dynamo *Dynamo) IsSyncChainDisabled(clientID string) (bool, error) { } // UpdateSyncEntity updates a sync item in dynamoDB. -func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bool, bool, error) { +func (dynamo *Dynamo) UpdateSyncEntity(ctx context.Context, entity *SyncEntity, oldVersion int64) (bool, bool, error) { primaryKey := PrimaryKey{ClientID: entity.ClientID, ID: entity.ID} key, err := attributevalue.MarshalMap(primaryKey) if err != nil { @@ -567,7 +567,7 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo items = append(items, updateSyncItem) items = append(items, deleteTagItem) - _, err = dynamo.TransactWriteItems(context.TODO(), + _, err = dynamo.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{TransactItems: items}) if err != nil { // Return conflict if the update condition fails. @@ -599,7 +599,7 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo TableName: aws.String(Table), } - out, err := dynamo.UpdateItem(context.TODO(), input) + out, err := dynamo.UpdateItem(ctx, input) if err != nil { var conditionalCheckFailedException *types.ConditionalCheckFailedException if errors.As(err, &conditionalCheckFailedException) { @@ -631,7 +631,7 @@ func (dynamo *Dynamo) UpdateSyncEntity(entity *SyncEntity, oldVersion int64) (bo // To do this in dynamoDB, we use (ClientID, DataType#Mtime) as GSI to get a // list of (ClientID, ID) primary keys with the given condition, then read the // actual sync item using the list of primary keys. -func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) { +func (dynamo *Dynamo) GetUpdatesForType(ctx context.Context, dataType int, clientToken int64, fetchFolders bool, clientID string, maxSize int64) (bool, []SyncEntity, error) { syncEntities := []SyncEntity{} // Get (ClientID, ID) pairs which are updates after mtime for a data type, @@ -667,7 +667,7 @@ func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFo Limit: aws.Int32(int32(maxSize)), } - out, err := dynamo.Query(context.TODO(), input) + out, err := dynamo.Query(ctx, input) if err != nil { return false, syncEntities, fmt.Errorf("error doing query to get updates: %w", err) } @@ -696,7 +696,7 @@ func (dynamo *Dynamo) GetUpdatesForType(dataType int, clientToken int64, fetchFo // Use paginator to automatically handle UnprocessedKeys paginator := dynamodb.NewBatchGetItemPaginator(dynamo.Client, batchInput) for paginator.HasMorePages() { - batchOut, err := paginator.NextPage(context.TODO()) + batchOut, err := paginator.NextPage(ctx) if err != nil { return false, syncEntities, fmt.Errorf("error getting update items in a batch: %w", err) } diff --git a/datastore/sync_entity_test.go b/datastore/sync_entity_test.go index 5c6d8919..231af3b9 100644 --- a/datastore/sync_entity_test.go +++ b/datastore/sync_entity_test.go @@ -1,6 +1,7 @@ package datastore_test import ( + "context" "encoding/json" "sort" "strconv" @@ -81,11 +82,11 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { } entity2 := entity1 entity2.ID = "id2" - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity with other ID should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity1) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().Error(err, "InsertSyncEntity with the same ClientID and ID should fail") // Each InsertSyncEntity without client tag should result in one sync item saved. @@ -102,20 +103,20 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { entity3 := entity1 entity3.ID = "id3" entity3.ClientDefinedUniqueTag = aws.String("tag1") - _, err = suite.dynamo.InsertSyncEntity(&entity3) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity3) suite.Require().NoError(err, "InsertSyncEntity should succeed") // Insert entity with different tag for same ClientID should succeed. entity4 := entity3 entity4.ID = "id4" entity4.ClientDefinedUniqueTag = aws.String("tag2") - _, err = suite.dynamo.InsertSyncEntity(&entity4) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity4) suite.Require().NoError(err, "InsertSyncEntity with different server tag should succeed") // Insert entity with the same client tag and ClientID should fail with conflict. entity4Copy := entity4 entity4Copy.ID = "id4_copy" - conflict, err := suite.dynamo.InsertSyncEntity(&entity4Copy) + conflict, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity4Copy) suite.Require().Error(err, "InsertSyncEntity with the same client tag and ClientID should fail") suite.True(conflict, "Return conflict for duplicate client tag") @@ -123,7 +124,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntity() { entity5 := entity3 entity5.ClientID = "client2" entity5.ID = "id5" - _, err = suite.dynamo.InsertSyncEntity(&entity5) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity5) suite.Require().NoError(err, "InsertSyncEntity with the same client tag for another client should succeed") @@ -177,22 +178,22 @@ func (suite *SyncEntityTestSuite) TestHasServerDefinedUniqueTag() { tag2.ServerDefinedUniqueTag = aws.String("tag2") entities := []*datastore.SyncEntity{&tag1, &tag2} - err := suite.dynamo.InsertSyncEntitiesWithServerTags(entities) + err := suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities) suite.Require().NoError(err, "Insert sync entities should succeed") - hasTag, err := suite.dynamo.HasServerDefinedUniqueTag("client1", "tag1") + hasTag, err := suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client1", "tag1") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.True(hasTag) - hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client1", "tag2") + hasTag, err = suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client1", "tag2") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.False(hasTag) - hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client2", "tag1") + hasTag, err = suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client2", "tag1") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.False(hasTag) - hasTag, err = suite.dynamo.HasServerDefinedUniqueTag("client2", "tag2") + hasTag, err = suite.dynamo.HasServerDefinedUniqueTag(context.Background(), "client2", "tag2") suite.Require().NoError(err, "HasServerDefinedUniqueTag should succeed") suite.True(hasTag) } @@ -215,24 +216,24 @@ func (suite *SyncEntityTestSuite) TestHasItem() { entity2.ClientID = "client2" entity2.ID = "id2" - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") - hasTag, err := suite.dynamo.HasItem("client1", "id1") + hasTag, err := suite.dynamo.HasItem(context.Background(), "client1", "id1") suite.Require().NoError(err, "HasItem should succeed") suite.True(hasTag) - hasTag, err = suite.dynamo.HasItem("client2", "id2") + hasTag, err = suite.dynamo.HasItem(context.Background(), "client2", "id2") suite.Require().NoError(err, "HasItem should succeed") suite.True(hasTag) - hasTag, err = suite.dynamo.HasItem("client2", "id3") + hasTag, err = suite.dynamo.HasItem(context.Background(), "client2", "id3") suite.Require().NoError(err, "HasItem should succeed") suite.False(hasTag) - hasTag, err = suite.dynamo.HasItem("client3", "id2") + hasTag, err = suite.dynamo.HasItem(context.Background(), "client3", "id2") suite.Require().NoError(err, "HasItem should succeed") suite.False(hasTag) } @@ -255,7 +256,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { entity2.ID = "id2" entities := []*datastore.SyncEntity{&entity1, &entity2} suite.Require().Error( - suite.dynamo.InsertSyncEntitiesWithServerTags(entities), + suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities), "Insert with same ClientID and server tag would fail") // Check nothing is written to DB when it fails. @@ -272,7 +273,7 @@ func (suite *SyncEntityTestSuite) TestInsertSyncEntitiesWithServerTags() { entity3.ID = "id3" entities = []*datastore.SyncEntity{&entity1, &entity2, &entity3} suite.Require().NoError( - suite.dynamo.InsertSyncEntitiesWithServerTags(entities), + suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities), "InsertSyncEntitiesWithServerTags should succeed") // Scan DB and check all items are saved @@ -321,11 +322,11 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { entity2.ID = "id2" entity3 := entity1 entity3.ID = "id3" - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity3) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity3) suite.Require().NoError(err, "InsertSyncEntity should succeed") // Check sync entities are inserted correctly in DB. syncItems, err := datastoretest.ScanSyncEntities(suite.dynamo) @@ -340,7 +341,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity1.Deleted = aws.Bool(true) updateEntity1.DataTypeMtime = aws.String("123#23456789") updateEntity1.Specifics = []byte{3, 4} - conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) + conflict, deleted, err := suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.False(conflict, "Successful update should not have conflict") suite.True(deleted, "Delete operation should return true") @@ -354,7 +355,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity2.ParentID = aws.String("parentID") updateEntity2.Name = aws.String("name") updateEntity2.NonUniqueName = aws.String("non_unique_name") - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, *entity2.Version) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity2, *entity2.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.False(conflict, "Successful update should not have conflict") suite.False(deleted, "Non-delete operation should return false") @@ -364,7 +365,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { updateEntity3.ID = "id3" updateEntity3.Folder = nil updateEntity3.Deleted = nil - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity3, *entity3.Version) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity3, *entity3.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.False(conflict, "Successful update should not have conflict") suite.False(deleted, "Non-delete operation should return false") @@ -374,7 +375,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_Basic() { // Update entity again with the wrong old version as (version mismatch) // should return false. - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, 12345678) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity2, 12345678) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.True(conflict, "Update with the same version should return conflict") suite.False(deleted, "Conflict operation should return false for delete") @@ -400,7 +401,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { DataTypeMtime: aws.String("123#12345678"), Specifics: []byte{1, 2}, } - conflict, err := suite.dynamo.InsertSyncEntity(&entity1) + conflict, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") suite.False(conflict, "Successful insert should not have conflict") @@ -408,7 +409,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { updateEntity1.Version = aws.Int64(2) updateEntity1.Folder = aws.Bool(true) updateEntity1.Mtime = aws.Int64(24242424) - conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, 1) + conflict, deleted, err := suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, 1) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.False(conflict, "Successful update should not have conflict") suite.False(deleted, "Non-delete operation should return false") @@ -417,7 +418,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { // since the version number should be ignored updateEntity2 := updateEntity1 updateEntity2.Mtime = aws.Int64(42424242) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity2, 1) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity2, 1) suite.Require().NoError(err, "UpdateSyncEntity should not return an error") suite.False(conflict, "Successful update should not have conflict") suite.False(deleted, "Non-delete operation should return false") @@ -425,7 +426,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_HistoryType() { updateEntity3 := entity1 updateEntity3.Deleted = aws.Bool(true) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity3, 1) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity3, 1) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.False(conflict, "Successful update should not have conflict") suite.True(deleted, "Delete operation should return true") @@ -450,7 +451,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { DataTypeMtime: aws.String("123#12345678"), Specifics: []byte{1, 2}, } - conflict, err := suite.dynamo.InsertSyncEntity(&entity1) + conflict, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") suite.False(conflict, "Successful insert should not have conflict") @@ -466,21 +467,21 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { updateEntity1.Folder = aws.Bool(true) updateEntity1.DataTypeMtime = aws.String("123#23456789") updateEntity1.Specifics = []byte{3, 4} - conflict, deleted, err := suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) + conflict, deleted, err := suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.False(conflict, "Successful update should not have conflict") suite.False(deleted, "Non-delete operation should return false") // Soft-delete the item with wrong version should get conflict. updateEntity1.Deleted = aws.Bool(true) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity1, *entity1.Version) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, *entity1.Version) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.True(conflict, "Version mismatched update should have conflict") suite.False(deleted, "Failed delete operation should return false") // Soft-delete the item with matched version. updateEntity1.Version = aws.Int64(34567890) - conflict, deleted, err = suite.dynamo.UpdateSyncEntity(&updateEntity1, 23456789) + conflict, deleted, err = suite.dynamo.UpdateSyncEntity(context.Background(), &updateEntity1, 23456789) suite.Require().NoError(err, "UpdateSyncEntity should succeed") suite.False(conflict, "Successful update should not have conflict") suite.True(deleted, "Delete operation should return true") @@ -493,7 +494,7 @@ func (suite *SyncEntityTestSuite) TestUpdateSyncEntity_ReuseClientTag() { // Insert another item with the same client tag again. entity2 := entity1 entity2.ID = "id2" - conflict, err = suite.dynamo.InsertSyncEntity(&entity2) + conflict, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") suite.False(conflict, "Successful insert should not have conflict") @@ -541,63 +542,63 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { entity5.ID = "id5" entity5.ExpirationTime = aws.Int64(time.Now().Unix() - 300) - _, err := suite.dynamo.InsertSyncEntity(&entity1) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity1) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity2) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity2) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity3) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity3) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity4) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity4) suite.Require().NoError(err, "InsertSyncEntity should succeed") - _, err = suite.dynamo.InsertSyncEntity(&entity5) + _, err = suite.dynamo.InsertSyncEntity(context.Background(), &entity5) suite.Require().NoError(err, "InsertSyncEntity should succeed") // Get all updates for type 123 and client1 using token = 0. - hasChangesRemaining, syncItems, err := suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 100) + hasChangesRemaining, syncItems, err := suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal([]datastore.SyncEntity{entity1, entity2}, syncItems) suite.False(hasChangesRemaining) // Get all updates for type 124 and client1 using token = 0. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(124, 0, true, "client1", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 124, 0, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal([]datastore.SyncEntity{entity3}, syncItems) suite.False(hasChangesRemaining) // Get all updates for type 123 and client2 using token = 0. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client2", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client2", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal([]datastore.SyncEntity{entity4}, syncItems) suite.False(hasChangesRemaining) // Get all updates for type 124 and client2 using token = 0. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(124, 0, true, "client2", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 124, 0, true, "client2", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Empty(syncItems) suite.False(hasChangesRemaining) // Test maxSize will limit the return entries size, and hasChangesRemaining // should be true when there are more updates available in the DB. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 1) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 1) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal([]datastore.SyncEntity{entity1}, syncItems) suite.True(hasChangesRemaining) // Test when num of query items equal to the limit, hasChangesRemaining should // be true. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 2) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 2) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal([]datastore.SyncEntity{entity1, entity2}, syncItems) suite.True(hasChangesRemaining) // Test fetchFolders will remove folder items if false - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, false, "client1", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, false, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal([]datastore.SyncEntity{entity2}, syncItems) suite.False(hasChangesRemaining) // Get all updates for a type for a client using mtime of one item as token. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 12345678, true, "client1", 100) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 12345678, true, "client1", 100) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal([]datastore.SyncEntity{entity2}, syncItems) suite.False(hasChangesRemaining) @@ -624,13 +625,13 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { entity.ID = "id" + strconv.Itoa(i) entity.Mtime = aws.Int64(mtime) entity.DataTypeMtime = aws.String("123#" + strconv.FormatInt(*entity.Mtime, 10)) - _, err := suite.dynamo.InsertSyncEntity(&entity) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity) suite.Require().NoError(err, "InsertSyncEntity should succeed") expectedSyncItems = append(expectedSyncItems, entity) } // All items should be returned and sorted by Mtime. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 300) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 300) suite.Require().NoError(err, "GetUpdatesForType should succeed") sort.Sort(datastore.SyncEntityByMtime(expectedSyncItems)) suite.Equal(expectedSyncItems, syncItems) @@ -638,7 +639,7 @@ func (suite *SyncEntityTestSuite) TestGetUpdatesForType() { // Test that when maxGUBatchSize is smaller than total updates, the first n // items ordered by Mtime should be returned. - hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(123, 0, true, "client1", 200) + hasChangesRemaining, syncItems, err = suite.dynamo.GetUpdatesForType(context.Background(), 123, 0, true, "client1", 200) suite.Require().NoError(err, "GetUpdatesForType should succeed") suite.Equal(syncItems, expectedSyncItems[0:200]) suite.True(hasChangesRemaining) @@ -849,7 +850,7 @@ func (suite *SyncEntityTestSuite) TestCreatePBSyncEntity() { func (suite *SyncEntityTestSuite) TestDisableSyncChain() { clientID := "client1" id := "disabled_chain" - err := suite.dynamo.DisableSyncChain(clientID) + err := suite.dynamo.DisableSyncChain(context.Background(), clientID) suite.Require().NoError(err, "DisableSyncChain should succeed") e, err := datastoretest.ScanTagItems(suite.dynamo) suite.Require().NoError(err, "ScanTagItems should succeed") @@ -861,13 +862,13 @@ func (suite *SyncEntityTestSuite) TestDisableSyncChain() { func (suite *SyncEntityTestSuite) TestIsSyncChainDisabled() { clientID := "client1" - disabled, err := suite.dynamo.IsSyncChainDisabled(clientID) + disabled, err := suite.dynamo.IsSyncChainDisabled(context.Background(), clientID) suite.Require().NoError(err, "IsSyncChainDisabled should succeed") suite.False(disabled) - err = suite.dynamo.DisableSyncChain(clientID) + err = suite.dynamo.DisableSyncChain(context.Background(), clientID) suite.Require().NoError(err, "DisableSyncChain should succeed") - disabled, err = suite.dynamo.IsSyncChainDisabled(clientID) + disabled, err = suite.dynamo.IsSyncChainDisabled(context.Background(), clientID) suite.Require().NoError(err, "IsSyncChainDisabled should succeed") suite.True(disabled) } @@ -885,14 +886,14 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { Deleted: aws.Bool(false), DataTypeMtime: aws.String("123#12345678"), } - _, err := suite.dynamo.InsertSyncEntity(&entity) + _, err := suite.dynamo.InsertSyncEntity(context.Background(), &entity) suite.Require().NoError(err, "InsertSyncEntity should succeed") e, err := datastoretest.ScanSyncEntities(suite.dynamo) suite.Require().NoError(err, "ScanSyncEntities should succeed") suite.Len(e, 1) - e, err = suite.dynamo.ClearServerData(entity.ClientID) + e, err = suite.dynamo.ClearServerData(context.Background(), entity.ClientID) suite.Require().NoError(err, "ClearServerData should succeed") suite.Len(e, 1) @@ -918,7 +919,7 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { entity2.ServerDefinedUniqueTag = aws.String("tag2") entities := []*datastore.SyncEntity{&entity1, &entity2} suite.Require().NoError( - suite.dynamo.InsertSyncEntitiesWithServerTags(entities), + suite.dynamo.InsertSyncEntitiesWithServerTags(context.Background(), entities), "InsertSyncEntitiesWithServerTags should succeed") e, err = datastoretest.ScanSyncEntities(suite.dynamo) @@ -929,7 +930,7 @@ func (suite *SyncEntityTestSuite) TestClearServerData() { suite.Require().NoError(err, "ScanTagItems should succeed") suite.Len(t, 2, "No items should be written if fail") - e, err = suite.dynamo.ClearServerData(entity.ClientID) + e, err = suite.dynamo.ClearServerData(context.Background(), entity.ClientID) suite.Require().NoError(err, "ClearServerData should succeed") suite.Len(e, 4) diff --git a/middleware/disabled_chain.go b/middleware/disabled_chain.go index 7aeb266c..7898613f 100644 --- a/middleware/disabled_chain.go +++ b/middleware/disabled_chain.go @@ -28,7 +28,7 @@ func DisabledChain(next http.Handler) http.Handler { return } - disabled, err := db.IsSyncChainDisabled(clientID) + disabled, err := db.IsSyncChainDisabled(ctx, clientID) if err != nil { http.Error(w, "unable to complete request", http.StatusInternalServerError) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 087586f7..dc418093 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -26,9 +26,9 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { // Active Chain datastore := new(datastoretest.MockDatastore) - datastore.On("IsSyncChainDisabled", clientID).Return(false, nil) ctx := context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) + datastore.On("IsSyncChainDisabled", ctx, clientID).Return(false, nil) next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) handler := middleware.DisabledChain(next) req, err := http.NewRequestWithContext(ctx, "POST", "v2/command/", bytes.NewBuffer([]byte{})) @@ -39,9 +39,9 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { // Disabled chain datastore = new(datastoretest.MockDatastore) - datastore.On("IsSyncChainDisabled", clientID).Return(true, nil) ctx = context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) + datastore.On("IsSyncChainDisabled", ctx, clientID).Return(true, nil) next = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { suite.Fail("Should not reach this point") }) @@ -54,9 +54,9 @@ func (suite *MiddlewareTestSuite) TestDisabledChainMiddleware() { // DB error datastore = new(datastoretest.MockDatastore) - datastore.On("IsSyncChainDisabled", clientID).Return(false, errors.New("unable to query db")) ctx = context.WithValue(context.Background(), syncContext.ContextKeyClientID, clientID) ctx = context.WithValue(ctx, syncContext.ContextKeyDatastore, datastore) + datastore.On("IsSyncChainDisabled", ctx, clientID).Return(false, errors.New("unable to query db")) next = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) handler = middleware.DisabledChain(next) req, err = http.NewRequestWithContext(ctx, "POST", "v2/command/", bytes.NewBuffer([]byte{}))