diff --git a/command/command.go b/command/command.go index 30359c75..8b8b541b 100644 --- a/command/command.go +++ b/command/command.go @@ -27,7 +27,6 @@ const ( setSyncPollInterval int32 = 30 nigoriTypeID int32 = 47745 deviceInfoTypeID int = 154522 - maxActiveDevices int = 50 historyCountTypeStr string = "history" normalCountTypeStr string = "normal" ) @@ -56,8 +55,8 @@ func handleGetUpdatesRequest(cache *cache.Cache, guMsg *sync_pb.GetUpdatesMessag activeDevices++ } - // Error out when exceeds the limit. - if activeDevices >= maxActiveDevices { + // Error out when device limit has been reached. + if hasReachedDeviceLimit(activeDevices, clientID) { errCode = sync_pb.SyncEnums_THROTTLED return &errCode, errors.New("exceed limit of active devices in a chain") } diff --git a/command/command_test.go b/command/command_test.go index 0dbdf834..e646d5c5 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/binary" "encoding/json" + "fmt" "sort" "strconv" "strings" @@ -98,6 +99,15 @@ func getBookmarkSpecifics() *sync_pb.EntitySpecifics { } } +func getDeviceInfoSpecifics() *sync_pb.EntitySpecifics { + deviceInfoEntitySpecifics := &sync_pb.EntitySpecifics_DeviceInfo{ + DeviceInfo: &sync_pb.DeviceInfoSpecifics{}, + } + return &sync_pb.EntitySpecifics{ + SpecificsVariant: deviceInfoEntitySpecifics, + } +} + func getCommitEntity(id string, version int64, deleted bool, specifics *sync_pb.EntitySpecifics) *sync_pb.SyncEntity { return &sync_pb.SyncEntity{ IdString: aws.String(id), @@ -378,6 +388,53 @@ func (suite *CommandTestSuite) TestHandleClientToServerMessage_NewClient() { suite.Equal(expectedEncryptionKeys, rsp.GetUpdates.EncryptionKeys) } +func (suite *CommandTestSuite) TestHandleClientToServerMessage_DeviceLimitExceeded() { + highDeviceLimitClientID := "high_device_limit_client_id" + command.LoadHighDeviceLimitClientIDs(fmt.Sprintf("randomid,%s,anotherrandomid", highDeviceLimitClientID)) + + testCases := []struct { + clientID string + expectedDeviceLimit int + }{ + {clientID: "client_id_1", expectedDeviceLimit: 50}, + {clientID: highDeviceLimitClientID, expectedDeviceLimit: 100}, + } + + for _, testCase := range testCases { + // Simulate devices calling GetUpdates with NEW_CLIENT origin up to the expected device limit. + marker := getMarker(suite, []int64{0, 0}) + msg := getClientToServerGUMsg( + marker, sync_pb.SyncEnums_NEW_CLIENT, true, nil) + for i := 1; i <= testCase.expectedDeviceLimit; i++ { + rsp := &sync_pb.ClientToServerResponse{} + + suite.Require().NoError( + command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + "HandleClientToServerMessage should succeed for device %d", i) + suite.Equal(sync_pb.SyncEnums_SUCCESS, *rsp.ErrorCode, "device %d should succeed", i) + suite.NotNil(rsp.GetUpdates, "device %d should have GetUpdates response", i) + + // Commit a device info entity after GetUpdates + deviceEntry := getCommitEntity(fmt.Sprintf("device_%d", i), 0, false, getDeviceInfoSpecifics()) + commitMsg := getClientToServerCommitMsg([]*sync_pb.SyncEntity{deviceEntry}) + commitRsp := &sync_pb.ClientToServerResponse{} + suite.Require().NoError( + command.HandleClientToServerMessage(suite.cache, commitMsg, commitRsp, suite.dynamo, testCase.clientID), + "Commit device info should succeed for device %d", i) + suite.Equal(sync_pb.SyncEnums_SUCCESS, *commitRsp.ErrorCode, "Commit device info should succeed for device %d", i) + } + + // should get THROTTLED error when device limit is exceeded + rsp := &sync_pb.ClientToServerResponse{} + suite.Require().NoError( + command.HandleClientToServerMessage(suite.cache, msg, rsp, suite.dynamo, testCase.clientID), + "HandleClientToServerMessage should succeed") + suite.Equal(sync_pb.SyncEnums_THROTTLED, *rsp.ErrorCode, "errorCode should be THROTTLED") + suite.Require().NotNil(rsp.ErrorMessage, "error message should be present") + suite.Contains(*rsp.ErrorMessage, "exceed limit of active devices") + } +} + func (suite *CommandTestSuite) TestHandleClientToServerMessage_GUBatchSize() { // Commit a few items for testing. entries := []*sync_pb.SyncEntity{ diff --git a/command/device_limit.go b/command/device_limit.go new file mode 100644 index 00000000..15095054 --- /dev/null +++ b/command/device_limit.go @@ -0,0 +1,38 @@ +package command + +import ( + "os" + "strings" +) + +const ( + maxActiveDevices int = 50 + highMaxActiveDevices int = 100 +) + +var ( + highDeviceLimitClientIDs map[string]bool +) + +func init() { + clientIDsEnv := os.Getenv("HIGH_DEVICE_LIMIT_CLIENT_IDS") + LoadHighDeviceLimitClientIDs(clientIDsEnv) +} + +func LoadHighDeviceLimitClientIDs(clientIDList string) { + highDeviceLimitClientIDs = make(map[string]bool) + if clientIDList != "" { + ids := strings.Split(clientIDList, ",") + for _, id := range ids { + highDeviceLimitClientIDs[strings.ToLower(strings.TrimSpace(id))] = true + } + } +} + +func hasReachedDeviceLimit(activeDevices int, clientID string) bool { + limit := maxActiveDevices + if highDeviceLimitClientIDs[strings.ToLower(clientID)] { + limit = highMaxActiveDevices + } + return activeDevices >= limit +}