From 50dfd74d880fef9c82b5a65702cf4842dd0d5d0a Mon Sep 17 00:00:00 2001 From: Connor McKelvey Date: Tue, 1 Jun 2021 17:18:26 -0600 Subject: [PATCH] Add support for lease stealing (#78) Fixes #4 Signed-off-by: Connor McKelvey Signed-off-by: Ali Hobbs Co-authored-by: Ali Hobbs Co-authored-by: Ali Hobbs --- HyperMake | 4 +- clientlibrary/checkpoint/checkpointer.go | 14 + .../checkpoint/dynamodb-checkpointer.go | 192 +++++++++- .../checkpoint/dynamodb-checkpointer_test.go | 355 +++++++++++++++++- clientlibrary/config/config.go | 24 ++ clientlibrary/config/config_test.go | 26 ++ clientlibrary/config/kcl-config.go | 21 +- clientlibrary/partition/partition.go | 24 ++ .../worker/polling-shard-consumer.go | 2 +- clientlibrary/worker/worker.go | 115 +++++- support/toolchain/docker/Dockerfile | 2 +- test/lease_stealing_util_test.go | 230 ++++++++++++ test/logger_test.go | 3 +- test/record_processor_test.go | 3 +- test/record_publisher_test.go | 88 ++++- test/worker_custom_test.go | 38 +- test/worker_lease_stealing_test.go | 127 +++++++ test/worker_test.go | 15 +- 18 files changed, 1233 insertions(+), 50 deletions(-) create mode 100644 test/lease_stealing_util_test.go create mode 100644 test/worker_lease_stealing_test.go diff --git a/HyperMake b/HyperMake index f444947..7ca3d06 100644 --- a/HyperMake +++ b/HyperMake @@ -8,8 +8,8 @@ targets: rebuild-toolchain: description: build toolchain image watches: - - support/docker/toolchain - build: support/docker/toolchain + - support/toolchain/docker + build: support/toolchain/docker toolchain: description: placeholder for additional toolchain dependencies diff --git a/clientlibrary/checkpoint/checkpointer.go b/clientlibrary/checkpoint/checkpointer.go index fe91359..4d4ceaa 100644 --- a/clientlibrary/checkpoint/checkpointer.go +++ b/clientlibrary/checkpoint/checkpointer.go @@ -40,9 +40,13 @@ const ( LeaseTimeoutKey = "LeaseTimeout" SequenceNumberKey = "Checkpoint" ParentShardIdKey = "ParentShardId" + ClaimRequestKey = "ClaimRequest" // 
We've completely processed all records in this shard. ShardEnd = "SHARD_END" + + // ErrShardClaimed is returned when shard is claimed + ErrShardClaimed = "Shard is already claimed by another node" ) type ErrLeaseNotAcquired struct { @@ -72,7 +76,17 @@ type Checkpointer interface { // RemoveLeaseOwner to remove lease owner for the shard entry to make the shard available for reassignment RemoveLeaseOwner(string) error + + // New Lease Stealing Methods + // ListActiveWorkers returns active workers and their shards + ListActiveWorkers(map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) + + // ClaimShard claims a shard for stealing + ClaimShard(*par.ShardStatus, string) error } // ErrSequenceIDNotFound is returned by FetchCheckpoint when no SequenceID is found var ErrSequenceIDNotFound = errors.New("SequenceIDNotFoundForShard") + +// ErrShardNotAssigned is returned by ListActiveWorkers when no AssignedTo is found +var ErrShardNotAssigned = errors.New("AssignedToNotFoundForShard") diff --git a/clientlibrary/checkpoint/dynamodb-checkpointer.go b/clientlibrary/checkpoint/dynamodb-checkpointer.go index dd8dd55..8df5e37 100644 --- a/clientlibrary/checkpoint/dynamodb-checkpointer.go +++ b/clientlibrary/checkpoint/dynamodb-checkpointer.go @@ -28,6 +28,8 @@ package checkpoint import ( + "errors" + "fmt" "time" "github.com/aws/aws-sdk-go/aws" @@ -61,6 +63,7 @@ type DynamoCheckpoint struct { svc dynamodbiface.DynamoDBAPI kclConfig *config.KinesisClientLibConfiguration Retries int + lastLeaseSync time.Time } func NewDynamoCheckpoint(kclConfig *config.KinesisClientLibConfiguration) *DynamoCheckpoint { @@ -124,8 +127,22 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign return err } + isClaimRequestExpired := shard.IsClaimRequestExpired(checkpointer.kclConfig) + + var claimRequest string + if checkpointer.kclConfig.EnableLeaseStealing { + if currentCheckpointClaimRequest, ok := currentCheckpoint[ClaimRequestKey]; ok && 
currentCheckpointClaimRequest.S != nil { + claimRequest = *currentCheckpointClaimRequest.S + if newAssignTo != claimRequest && !isClaimRequestExpired { + checkpointer.log.Debugf("another worker: %s has a claim on this shard. Not going to renew the lease", claimRequest) + return errors.New(ErrShardClaimed) + } + } + } + assignedVar, assignedToOk := currentCheckpoint[LeaseOwnerKey] leaseVar, leaseTimeoutOk := currentCheckpoint[LeaseTimeoutKey] + var conditionalExpression string var expressionAttributeValues map[string]*dynamodb.AttributeValue @@ -140,8 +157,14 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign return err } - if time.Now().UTC().Before(currentLeaseTimeout) && assignedTo != newAssignTo { - return ErrLeaseNotAcquired{"current lease timeout not yet expired"} + if checkpointer.kclConfig.EnableLeaseStealing { + if time.Now().UTC().Before(currentLeaseTimeout) && assignedTo != newAssignTo && !isClaimRequestExpired { + return ErrLeaseNotAcquired{"current lease timeout not yet expired"} + } + } else { + if time.Now().UTC().Before(currentLeaseTimeout) && assignedTo != newAssignTo { + return ErrLeaseNotAcquired{"current lease timeout not yet expired"} + } } checkpointer.log.Debugf("Attempting to get a lock for shard: %s, leaseTimeout: %s, assignedTo: %s, newAssignedTo: %s", shard.ID, currentLeaseTimeout, assignedTo, newAssignTo) @@ -175,9 +198,21 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign marshalledCheckpoint[ParentShardIdKey] = &dynamodb.AttributeValue{S: aws.String(shard.ParentShardId)} } - if shard.GetCheckpoint() != "" { + if checkpoint := shard.GetCheckpoint(); checkpoint != "" { marshalledCheckpoint[SequenceNumberKey] = &dynamodb.AttributeValue{ - S: aws.String(shard.GetCheckpoint()), + S: aws.String(checkpoint), + } + } + + if checkpointer.kclConfig.EnableLeaseStealing { + if claimRequest != "" && claimRequest == newAssignTo && !isClaimRequestExpired { + if expressionAttributeValues 
== nil { + expressionAttributeValues = make(map[string]*dynamodb.AttributeValue) + } + conditionalExpression = conditionalExpression + " AND ClaimRequest = :claim_request" + expressionAttributeValues[":claim_request"] = &dynamodb.AttributeValue{ + S: &claimRequest, + } } } @@ -199,7 +234,7 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *par.ShardStatus, newAssign // CheckpointSequence writes a checkpoint at the designated sequence ID func (checkpointer *DynamoCheckpoint) CheckpointSequence(shard *par.ShardStatus) error { - leaseTimeout := shard.LeaseTimeout.UTC().Format(time.RFC3339) + leaseTimeout := shard.GetLeaseTimeout().UTC().Format(time.RFC3339) marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ LeaseKeyKey: { S: aws.String(shard.ID), @@ -208,7 +243,7 @@ func (checkpointer *DynamoCheckpoint) CheckpointSequence(shard *par.ShardStatus) S: aws.String(shard.GetCheckpoint()), }, LeaseOwnerKey: { - S: aws.String(shard.AssignedTo), + S: aws.String(shard.GetLeaseOwner()), }, LeaseTimeoutKey: { S: aws.String(leaseTimeout), @@ -239,6 +274,16 @@ func (checkpointer *DynamoCheckpoint) FetchCheckpoint(shard *par.ShardStatus) er if assignedTo, ok := checkpoint[LeaseOwnerKey]; ok { shard.SetLeaseOwner(aws.StringValue(assignedTo.S)) } + + // Use up-to-date leaseTimeout to avoid ConditionalCheckFailedException when claiming + if leaseTimeout, ok := checkpoint[LeaseTimeoutKey]; ok && leaseTimeout.S != nil { + currentLeaseTimeout, err := time.Parse(time.RFC3339, aws.StringValue(leaseTimeout.S)) + if err != nil { + return err + } + shard.LeaseTimeout = currentLeaseTimeout + } + return nil } @@ -265,6 +310,12 @@ func (checkpointer *DynamoCheckpoint) RemoveLeaseOwner(shardID string) error { }, }, UpdateExpression: aws.String("remove " + LeaseOwnerKey), + ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ + ":assigned_to": { + S: aws.String(checkpointer.kclConfig.WorkerID), + }, + }, + ConditionExpression: aws.String("AssignedTo = :assigned_to"), } _, 
err := checkpointer.svc.UpdateItem(input) @@ -272,6 +323,135 @@ func (checkpointer *DynamoCheckpoint) RemoveLeaseOwner(shardID string) error { return err } +// ListActiveWorkers returns a map of workers and their shards +func (checkpointer *DynamoCheckpoint) ListActiveWorkers(shardStatus map[string]*par.ShardStatus) (map[string][]*par.ShardStatus, error) { + err := checkpointer.syncLeases(shardStatus) + if err != nil { + return nil, err + } + + workers := map[string][]*par.ShardStatus{} + for _, shard := range shardStatus { + if shard.GetCheckpoint() == ShardEnd { + continue + } + + leaseOwner := shard.GetLeaseOwner() + if leaseOwner == "" { + checkpointer.log.Debugf("Shard Not Assigned Error. ShardID: %s, WorkerID: %s", shard.ID, checkpointer.kclConfig.WorkerID) + return nil, ErrShardNotAssigned + } + if w, ok := workers[leaseOwner]; ok { + workers[leaseOwner] = append(w, shard) + } else { + workers[leaseOwner] = []*par.ShardStatus{shard} + } + } + return workers, nil +} + +// ClaimShard places a claim request on a shard to signal a steal attempt +func (checkpointer *DynamoCheckpoint) ClaimShard(shard *par.ShardStatus, claimID string) error { + err := checkpointer.FetchCheckpoint(shard) + if err != nil && err != ErrSequenceIDNotFound { + return err + } + leaseTimeoutString := shard.GetLeaseTimeout().Format(time.RFC3339) + + conditionalExpression := `ShardID = :id AND LeaseTimeout = :lease_timeout AND attribute_not_exists(ClaimRequest)` + expressionAttributeValues := map[string]*dynamodb.AttributeValue{ + ":id": { + S: aws.String(shard.ID), + }, + ":lease_timeout": { + S: aws.String(leaseTimeoutString), + }, + } + + marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ + LeaseKeyKey: { + S: &shard.ID, + }, + LeaseTimeoutKey: { + S: &leaseTimeoutString, + }, + SequenceNumberKey: { + S: &shard.Checkpoint, + }, + ClaimRequestKey: { + S: &claimID, + }, + } + + if leaseOwner := shard.GetLeaseOwner(); leaseOwner == "" { + conditionalExpression += " AND 
attribute_not_exists(AssignedTo)" + } else { + marshalledCheckpoint[LeaseOwnerKey] = &dynamodb.AttributeValue{S: &leaseOwner} + conditionalExpression += " AND AssignedTo = :assigned_to" + expressionAttributeValues[":assigned_to"] = &dynamodb.AttributeValue{S: &leaseOwner} + } + + if checkpoint := shard.GetCheckpoint(); checkpoint == "" { + conditionalExpression += " AND attribute_not_exists(Checkpoint)" + } else if checkpoint == ShardEnd { + conditionalExpression += " AND Checkpoint <> :checkpoint" + expressionAttributeValues[":checkpoint"] = &dynamodb.AttributeValue{S: aws.String(ShardEnd)} + } else { + conditionalExpression += " AND Checkpoint = :checkpoint" + expressionAttributeValues[":checkpoint"] = &dynamodb.AttributeValue{S: &checkpoint} + } + + if shard.ParentShardId == "" { + conditionalExpression += " AND attribute_not_exists(ParentShardId)" + } else { + marshalledCheckpoint[ParentShardIdKey] = &dynamodb.AttributeValue{S: aws.String(shard.ParentShardId)} + conditionalExpression += " AND ParentShardId = :parent_shard" + expressionAttributeValues[":parent_shard"] = &dynamodb.AttributeValue{S: &shard.ParentShardId} + } + + return checkpointer.conditionalUpdate(conditionalExpression, expressionAttributeValues, marshalledCheckpoint) +} + +func (checkpointer *DynamoCheckpoint) syncLeases(shardStatus map[string]*par.ShardStatus) error { + log := checkpointer.kclConfig.Logger + + if (checkpointer.lastLeaseSync.Add(time.Duration(checkpointer.kclConfig.LeaseSyncingTimeIntervalMillis) * time.Millisecond)).After(time.Now()) { + return nil + } + + checkpointer.lastLeaseSync = time.Now() + input := &dynamodb.ScanInput{ + ProjectionExpression: aws.String(fmt.Sprintf("%s,%s,%s", LeaseKeyKey, LeaseOwnerKey, SequenceNumberKey)), + Select: aws.String("SPECIFIC_ATTRIBUTES"), + TableName: aws.String(checkpointer.kclConfig.TableName), + } + + err := checkpointer.svc.ScanPages(input, + func(pages *dynamodb.ScanOutput, lastPage bool) bool { + results := pages.Items + for _, 
result := range results { + shardId, foundShardId := result[LeaseKeyKey] + assignedTo, foundAssignedTo := result[LeaseOwnerKey] + checkpoint, foundCheckpoint := result[SequenceNumberKey] + if !foundShardId || !foundAssignedTo || !foundCheckpoint { + continue + } + if shard, ok := shardStatus[aws.StringValue(shardId.S)]; ok { + shard.SetLeaseOwner(aws.StringValue(assignedTo.S)) + shard.SetCheckpoint(aws.StringValue(checkpoint.S)) + } + } + return !lastPage + }) + + if err != nil { + log.Debugf("Error performing SyncLeases. Error: %+v ", err) + return err + } + log.Debugf("Lease sync completed. Next lease sync will occur in %s", time.Duration(checkpointer.kclConfig.LeaseSyncingTimeIntervalMillis)*time.Millisecond) + return nil +} + func (checkpointer *DynamoCheckpoint) createTable() error { input := &dynamodb.CreateTableInput{ AttributeDefinitions: []*dynamodb.AttributeDefinition{ diff --git a/clientlibrary/checkpoint/dynamodb-checkpointer_test.go b/clientlibrary/checkpoint/dynamodb-checkpointer_test.go index 2217b0e..38da0b3 100644 --- a/clientlibrary/checkpoint/dynamodb-checkpointer_test.go +++ b/clientlibrary/checkpoint/dynamodb-checkpointer_test.go @@ -85,6 +85,7 @@ func TestGetLeaseNotAquired(t *testing.T) { Checkpoint: "", Mux: &sync.RWMutex{}, }, "ijkl-mnop") + if err == nil || !errors.As(err, &ErrLeaseNotAcquired{}) { t.Errorf("Got a lease when it was already held by abcd-efgh: %s", err) } @@ -102,16 +103,16 @@ func TestGetLeaseAquired(t *testing.T) { checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) checkpoint.Init() marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ - "ShardID": { + LeaseKeyKey: { S: aws.String("0001"), }, - "AssignedTo": { + LeaseOwnerKey: { S: aws.String("abcd-efgh"), }, - "LeaseTimeout": { + LeaseTimeoutKey: { S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)), }, - "SequenceID": { + SequenceNumberKey: { S: aws.String("deadbeef"), }, } @@ -156,10 +157,221 @@ func TestGetLeaseAquired(t 
*testing.T) { assert.Equal(t, "", status.GetLeaseOwner()) } +func TestGetLeaseShardClaimed(t *testing.T) { + leaseTimeout := time.Now().Add(-100 * time.Second).UTC() + svc := &mockDynamoDB{ + tableExist: true, + item: map[string]*dynamodb.AttributeValue{ + ClaimRequestKey: {S: aws.String("ijkl-mnop")}, + LeaseTimeoutKey: {S: aws.String(leaseTimeout.Format(time.RFC3339))}, + }, + } + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithInitialPositionInStream(cfg.LATEST). + WithMaxRecords(10). + WithMaxLeasesForWorker(1). + WithShardSyncIntervalMillis(5000). + WithFailoverTimeMillis(300000). + WithLeaseStealing(true) + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + checkpoint.Init() + err := checkpoint.GetLease(&par.ShardStatus{ + ID: "0001", + Checkpoint: "", + LeaseTimeout: leaseTimeout, + Mux: &sync.RWMutex{}, + }, "abcd-efgh") + if err == nil || err.Error() != ErrShardClaimed { + t.Errorf("Got a lease when it was already claimed by by ijkl-mnop: %s", err) + } + + err = checkpoint.GetLease(&par.ShardStatus{ + ID: "0001", + Checkpoint: "", + LeaseTimeout: leaseTimeout, + Mux: &sync.RWMutex{}, + }, "ijkl-mnop") + if err != nil { + t.Errorf("Error getting lease %s", err) + } +} + +func TestGetLeaseClaimRequestExpiredOwner(t *testing.T) { + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithInitialPositionInStream(cfg.LATEST). + WithMaxRecords(10). + WithMaxLeasesForWorker(1). + WithShardSyncIntervalMillis(5000). + WithFailoverTimeMillis(300000). + WithLeaseStealing(true) + + // Not expired + leaseTimeout := time.Now(). + Add(-time.Duration(kclConfig.LeaseStealingClaimTimeoutMillis) * time.Millisecond). + Add(1 * time.Second). 
+ UTC() + + svc := &mockDynamoDB{ + tableExist: true, + item: map[string]*dynamodb.AttributeValue{ + LeaseOwnerKey: {S: aws.String("abcd-efgh")}, + ClaimRequestKey: {S: aws.String("ijkl-mnop")}, + LeaseTimeoutKey: {S: aws.String(leaseTimeout.Format(time.RFC3339))}, + }, + } + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + checkpoint.Init() + err := checkpoint.GetLease(&par.ShardStatus{ + ID: "0001", + Checkpoint: "", + LeaseTimeout: leaseTimeout, + Mux: &sync.RWMutex{}, + }, "abcd-efgh") + if err == nil || err.Error() != ErrShardClaimed { + t.Errorf("Got a lease when it was already claimed by ijkl-mnop: %s", err) + } +} + +func TestGetLeaseClaimRequestExpiredClaimer(t *testing.T) { + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithInitialPositionInStream(cfg.LATEST). + WithMaxRecords(10). + WithMaxLeasesForWorker(1). + WithShardSyncIntervalMillis(5000). + WithFailoverTimeMillis(300000). + WithLeaseStealing(true) + + // Not expired + leaseTimeout := time.Now(). + Add(-time.Duration(kclConfig.LeaseStealingClaimTimeoutMillis) * time.Millisecond). + Add(121 * time.Second). 
+ UTC() + + svc := &mockDynamoDB{ + tableExist: true, + item: map[string]*dynamodb.AttributeValue{ + LeaseOwnerKey: {S: aws.String("abcd-efgh")}, + ClaimRequestKey: {S: aws.String("ijkl-mnop")}, + LeaseTimeoutKey: {S: aws.String(leaseTimeout.Format(time.RFC3339))}, + }, + } + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + checkpoint.Init() + err := checkpoint.GetLease(&par.ShardStatus{ + ID: "0001", + Checkpoint: "", + LeaseTimeout: leaseTimeout, + Mux: &sync.RWMutex{}, + }, "ijkl-mnop") + if err == nil || !errors.As(err, &ErrLeaseNotAcquired{}) { + t.Errorf("Got a lease when it was already claimed by ijkl-mnop: %s", err) + } +} + +func TestFetchCheckpointWithStealing(t *testing.T) { + future := time.Now().AddDate(0, 1, 0) + + svc := &mockDynamoDB{ + tableExist: true, + item: map[string]*dynamodb.AttributeValue{ + SequenceNumberKey: {S: aws.String("deadbeef")}, + LeaseOwnerKey: {S: aws.String("abcd-efgh")}, + LeaseTimeoutKey: { + S: aws.String(future.Format(time.RFC3339)), + }, + }, + } + + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithInitialPositionInStream(cfg.LATEST). + WithMaxRecords(10). + WithMaxLeasesForWorker(1). + WithShardSyncIntervalMillis(5000). + WithFailoverTimeMillis(300000). + WithLeaseStealing(true) + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + checkpoint.Init() + + status := &par.ShardStatus{ + ID: "0001", + Checkpoint: "", + LeaseTimeout: time.Now(), + Mux: &sync.RWMutex{}, + } + + checkpoint.FetchCheckpoint(status) + + leaseTimeout, _ := time.Parse(time.RFC3339, *svc.item[LeaseTimeoutKey].S) + assert.Equal(t, leaseTimeout, status.LeaseTimeout) +} + +func TestGetLeaseConditional(t *testing.T) { + svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}} + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithInitialPositionInStream(cfg.LATEST). + WithMaxRecords(10). + WithMaxLeasesForWorker(1). 
+ WithShardSyncIntervalMillis(5000). + WithFailoverTimeMillis(300000). + WithLeaseStealing(true) + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + checkpoint.Init() + marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ + LeaseKeyKey: { + S: aws.String("0001"), + }, + LeaseOwnerKey: { + S: aws.String("abcd-efgh"), + }, + LeaseTimeoutKey: { + S: aws.String(time.Now().Add(-1 * time.Second).UTC().Format(time.RFC3339)), + }, + SequenceNumberKey: { + S: aws.String("deadbeef"), + }, + ClaimRequestKey: { + S: aws.String("ijkl-mnop"), + }, + } + input := &dynamodb.PutItemInput{ + TableName: aws.String("TableName"), + Item: marshalledCheckpoint, + } + checkpoint.svc.PutItem(input) + shard := &par.ShardStatus{ + ID: "0001", + Checkpoint: "deadbeef", + ClaimRequest: "ijkl-mnop", + Mux: &sync.RWMutex{}, + } + err := checkpoint.FetchCheckpoint(shard) + if err != nil { + t.Errorf("Could not fetch checkpoint %s", err) + } + + err = checkpoint.GetLease(shard, "ijkl-mnop") + if err != nil { + t.Errorf("Lease not aquired after timeout %s", err) + } + assert.Equal(t, *svc.expressionAttributeValues[":claim_request"].S, "ijkl-mnop") + assert.Contains(t, svc.conditionalExpression, " AND ClaimRequest = :claim_request") +} + type mockDynamoDB struct { dynamodbiface.DynamoDBAPI - tableExist bool - item map[string]*dynamodb.AttributeValue + tableExist bool + item map[string]*dynamodb.AttributeValue + conditionalExpression string + expressionAttributeValues map[string]*dynamodb.AttributeValue +} + +func (m *mockDynamoDB) ScanPages(*dynamodb.ScanInput, func(*dynamodb.ScanOutput, bool) bool) error { + return nil } func (m *mockDynamoDB) DescribeTable(*dynamodb.DescribeTableInput) (*dynamodb.DescribeTableOutput, error) { @@ -192,6 +404,16 @@ func (m *mockDynamoDB) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemO m.item[ParentShardIdKey] = parent } + if claimRequest, ok := item[ClaimRequestKey]; ok { + m.item[ClaimRequestKey] = claimRequest + } + + if 
input.ConditionExpression != nil { + m.conditionalExpression = *input.ConditionExpression + } + + m.expressionAttributeValues = input.ExpressionAttributeValues + return nil, nil } @@ -214,3 +436,124 @@ func (m *mockDynamoDB) UpdateItem(input *dynamodb.UpdateItemInput) (*dynamodb.Up func (m *mockDynamoDB) CreateTable(input *dynamodb.CreateTableInput) (*dynamodb.CreateTableOutput, error) { return &dynamodb.CreateTableOutput{}, nil } + +func TestListActiveWorkers(t *testing.T) { + svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}} + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithLeaseStealing(true) + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + err := checkpoint.Init() + if err != nil { + t.Errorf("Checkpoint initialization failed: %+v", err) + } + + shardStatus := map[string]*par.ShardStatus{ + "0000": {ID: "0000", AssignedTo: "worker_1", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0001": {ID: "0001", AssignedTo: "worker_2", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0002": {ID: "0002", AssignedTo: "worker_4", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0003": {ID: "0003", AssignedTo: "worker_0", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0004": {ID: "0004", AssignedTo: "worker_1", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0005": {ID: "0005", AssignedTo: "worker_3", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0006": {ID: "0006", AssignedTo: "worker_3", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0007": {ID: "0007", AssignedTo: "worker_0", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0008": {ID: "0008", AssignedTo: "worker_4", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0009": {ID: "0009", AssignedTo: "worker_2", Checkpoint: "", Mux: &sync.RWMutex{}}, + "0010": {ID: "0010", AssignedTo: "worker_0", Checkpoint: ShardEnd, Mux: &sync.RWMutex{}}, + } + + workers, err := checkpoint.ListActiveWorkers(shardStatus) + if err != nil { + t.Error(err) + } + + for workerID, shards := range workers { + 
assert.Equal(t, 2, len(shards)) + for _, shard := range shards { + assert.Equal(t, workerID, shard.AssignedTo) + } + } +} + +func TestListActiveWorkersErrShardNotAssigned(t *testing.T) { + svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}} + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithLeaseStealing(true) + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + err := checkpoint.Init() + if err != nil { + t.Errorf("Checkpoint initialization failed: %+v", err) + } + + shardStatus := map[string]*par.ShardStatus{ + "0000": {ID: "0000", Mux: &sync.RWMutex{}}, + } + + _, err = checkpoint.ListActiveWorkers(shardStatus) + if err != ErrShardNotAssigned { + t.Error("Expected ErrShardNotAssigned when shard is missing AssignedTo value") + } +} + +func TestClaimShard(t *testing.T) { + svc := &mockDynamoDB{tableExist: true, item: map[string]*dynamodb.AttributeValue{}} + kclConfig := cfg.NewKinesisClientLibConfig("appName", "test", "us-west-2", "abc"). + WithInitialPositionInStream(cfg.LATEST). + WithMaxRecords(10). + WithMaxLeasesForWorker(1). + WithShardSyncIntervalMillis(5000). + WithFailoverTimeMillis(300000). 
+ WithLeaseStealing(true) + + checkpoint := NewDynamoCheckpoint(kclConfig).WithDynamoDB(svc) + checkpoint.Init() + + marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ + "ShardID": { + S: aws.String("0001"), + }, + "AssignedTo": { + S: aws.String("abcd-efgh"), + }, + "LeaseTimeout": { + S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)), + }, + "Checkpoint": { + S: aws.String("deadbeef"), + }, + } + input := &dynamodb.PutItemInput{ + TableName: aws.String("TableName"), + Item: marshalledCheckpoint, + } + checkpoint.svc.PutItem(input) + shard := &par.ShardStatus{ + ID: "0001", + Checkpoint: "deadbeef", + Mux: &sync.RWMutex{}, + } + + err := checkpoint.ClaimShard(shard, "ijkl-mnop") + if err != nil { + t.Errorf("Shard not claimed %s", err) + } + + claimRequest, ok := svc.item[ClaimRequestKey] + if !ok { + t.Error("Expected claimRequest to be set by ClaimShard") + } else if *claimRequest.S != "ijkl-mnop" { + t.Errorf("Expected checkpoint to be ijkl-mnop. Got '%s'", *claimRequest.S) + } + + status := &par.ShardStatus{ + ID: shard.ID, + Mux: &sync.RWMutex{}, + } + checkpoint.FetchCheckpoint(status) + + // asiggnedTo, checkpointer, and parent shard id should be the same + assert.Equal(t, shard.AssignedTo, status.AssignedTo) + assert.Equal(t, shard.Checkpoint, status.Checkpoint) + assert.Equal(t, shard.ParentShardId, status.ParentShardId) +} diff --git a/clientlibrary/config/config.go b/clientlibrary/config/config.go index f8102eb..9f3f002 100644 --- a/clientlibrary/config/config.go +++ b/clientlibrary/config/config.go @@ -122,6 +122,18 @@ const ( // The amount of milliseconds to wait before graceful shutdown forcefully terminates. DefaultShutdownGraceMillis = 5000 + + // Lease stealing defaults to false for backwards compatibility. + DefaultEnableLeaseStealing = false + + // Interval between rebalance tasks defaults to 5 seconds. 
+ DefaultLeaseStealingIntervalMillis = 5000 + + // Number of milliseconds to wait before another worker can acquire a claimed shard + DefaultLeaseStealingClaimTimeoutMillis = 120000 + + // Number of milliseconds to wait before syncing with lease table (dynamoDB) + DefaultLeaseSyncingIntervalMillis = 60000 ) type ( @@ -257,6 +269,18 @@ type ( // MonitoringService publishes per worker-scoped metrics. MonitoringService metrics.MonitoringService + + // EnableLeaseStealing turns on lease stealing + EnableLeaseStealing bool + + // LeaseStealingIntervalMillis The number of milliseconds between rebalance tasks + LeaseStealingIntervalMillis int + + // LeaseStealingClaimTimeoutMillis The number of milliseconds to wait before another worker can acquire a claimed shard + LeaseStealingClaimTimeoutMillis int + + // LeaseSyncingTimeIntervalMillis The number of milliseconds to wait before syncing with lease table (dynamoDB) + LeaseSyncingTimeIntervalMillis int + } ) diff --git a/clientlibrary/config/config_test.go b/clientlibrary/config/config_test.go index c02dfab..1785e91 100644 --- a/clientlibrary/config/config_test.go +++ b/clientlibrary/config/config_test.go @@ -39,9 +39,35 @@ func TestConfig(t *testing.T) { assert.Equal(t, "appName", kclConfig.ApplicationName) assert.Equal(t, 500, kclConfig.FailoverTimeMillis) assert.Equal(t, 10, kclConfig.TaskBackoffTimeMillis) + assert.True(t, kclConfig.EnableEnhancedFanOutConsumer) assert.Equal(t, "fan-out-consumer", kclConfig.EnhancedFanOutConsumerName) + assert.Equal(t, false, kclConfig.EnableLeaseStealing) + assert.Equal(t, 5000, kclConfig.LeaseStealingIntervalMillis) + + contextLogger := kclConfig.Logger.WithFields(logger.Fields{"key1": "value1"}) + contextLogger.Debugf("Starting with default logger") + contextLogger.Infof("Default logger is awesome") +} + +func TestConfigLeaseStealing(t *testing.T) { + kclConfig := NewKinesisClientLibConfig("appName", "StreamName", "us-west-2", "workerId"). + WithFailoverTimeMillis(500). 
+ WithMaxRecords(100). + WithInitialPositionInStream(TRIM_HORIZON). + WithIdleTimeBetweenReadsInMillis(20). + WithCallProcessRecordsEvenForEmptyRecordList(true). + WithTaskBackoffTimeMillis(10). + WithLeaseStealing(true). + WithLeaseStealingIntervalMillis(10000) + + assert.Equal(t, "appName", kclConfig.ApplicationName) + assert.Equal(t, 500, kclConfig.FailoverTimeMillis) + assert.Equal(t, 10, kclConfig.TaskBackoffTimeMillis) + assert.Equal(t, true, kclConfig.EnableLeaseStealing) + assert.Equal(t, 10000, kclConfig.LeaseStealingIntervalMillis) + contextLogger := kclConfig.Logger.WithFields(logger.Fields{"key1": "value1"}) contextLogger.Debugf("Starting with default logger") contextLogger.Infof("Default logger is awesome") diff --git a/clientlibrary/config/kcl-config.go b/clientlibrary/config/kcl-config.go index 91f39b7..a831e88 100644 --- a/clientlibrary/config/kcl-config.go +++ b/clientlibrary/config/kcl-config.go @@ -95,7 +95,11 @@ func NewKinesisClientLibConfigWithCredentials(applicationName, streamName, regio InitialLeaseTableReadCapacity: DefaultInitialLeaseTableReadCapacity, InitialLeaseTableWriteCapacity: DefaultInitialLeaseTableWriteCapacity, SkipShardSyncAtWorkerInitializationIfLeasesExist: DefaultSkipShardSyncAtStartupIfLeasesExist, - Logger: logger.GetDefaultLogger(), + EnableLeaseStealing: DefaultEnableLeaseStealing, + LeaseStealingIntervalMillis: DefaultLeaseStealingIntervalMillis, + LeaseStealingClaimTimeoutMillis: DefaultLeaseStealingClaimTimeoutMillis, + LeaseSyncingTimeIntervalMillis: DefaultLeaseSyncingIntervalMillis, + Logger: logger.GetDefaultLogger(), } } @@ -241,3 +245,18 @@ func (c *KinesisClientLibConfiguration) WithEnhancedFanOutConsumerARN(consumerAR c.EnableEnhancedFanOutConsumer = true return c } + +func (c *KinesisClientLibConfiguration) WithLeaseStealing(enableLeaseStealing bool) *KinesisClientLibConfiguration { + c.EnableLeaseStealing = enableLeaseStealing + return c +} + +func (c *KinesisClientLibConfiguration) 
WithLeaseStealingIntervalMillis(leaseStealingIntervalMillis int) *KinesisClientLibConfiguration { + c.LeaseStealingIntervalMillis = leaseStealingIntervalMillis + return c +} + +func (c *KinesisClientLibConfiguration) WithLeaseSyncingIntervalMillis(leaseSyncingIntervalMillis int) *KinesisClientLibConfiguration { + c.LeaseSyncingTimeIntervalMillis = leaseSyncingIntervalMillis + return c +} diff --git a/clientlibrary/partition/partition.go b/clientlibrary/partition/partition.go index 955bf08..b3f287f 100644 --- a/clientlibrary/partition/partition.go +++ b/clientlibrary/partition/partition.go @@ -30,6 +30,8 @@ package worker import ( "sync" "time" + + "github.com/vmware/vmware-go-kcl/clientlibrary/config" ) type ShardStatus struct { @@ -43,6 +45,7 @@ type ShardStatus struct { StartingSequenceNumber string // child shard doesn't have end sequence number EndingSequenceNumber string + ClaimRequest string } func (ss *ShardStatus) GetLeaseOwner() string { @@ -68,3 +71,24 @@ func (ss *ShardStatus) SetCheckpoint(c string) { defer ss.Mux.Unlock() ss.Checkpoint = c } + +func (ss *ShardStatus) GetLeaseTimeout() time.Time { + ss.Mux.Lock() + defer ss.Mux.Unlock() + return ss.LeaseTimeout +} + +func (ss *ShardStatus) SetLeaseTimeout(timeout time.Time) { + ss.Mux.Lock() + defer ss.Mux.Unlock() + ss.LeaseTimeout = timeout +} + +func (ss *ShardStatus) IsClaimRequestExpired(kclConfig *config.KinesisClientLibConfiguration) bool { + if leaseTimeout := ss.GetLeaseTimeout(); leaseTimeout.IsZero() { + return false + } else { + return leaseTimeout. 
+ Before(time.Now().UTC().Add(time.Duration(-kclConfig.LeaseStealingClaimTimeoutMillis) * time.Millisecond)) + } +} diff --git a/clientlibrary/worker/polling-shard-consumer.go b/clientlibrary/worker/polling-shard-consumer.go index 27e5c80..90371b0 100644 --- a/clientlibrary/worker/polling-shard-consumer.go +++ b/clientlibrary/worker/polling-shard-consumer.go @@ -103,7 +103,7 @@ func (sc *PollingShardConsumer) getRecords() error { retriedErrors := 0 for { - if time.Now().UTC().After(sc.shard.LeaseTimeout.Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) { + if time.Now().UTC().After(sc.shard.GetLeaseTimeout().Add(-time.Duration(sc.kclConfig.LeaseRefreshPeriodMillis) * time.Millisecond)) { log.Debugf("Refreshing lease on shard: %s for worker: %s", sc.shard.ID, sc.consumerID) err = sc.checkpointer.GetLease(sc.shard, sc.consumerID) if err != nil { diff --git a/clientlibrary/worker/worker.go b/clientlibrary/worker/worker.go index 0ab4d17..fb2dd4a 100644 --- a/clientlibrary/worker/worker.go +++ b/clientlibrary/worker/worker.go @@ -68,7 +68,8 @@ type Worker struct { rng *rand.Rand - shardStatus map[string]*par.ShardStatus + shardStatus map[string]*par.ShardStatus + shardStealInProgress bool } // NewWorker constructs a Worker instance for processing Kinesis stream data. 
@@ -271,7 +272,7 @@ func (w *Worker) eventLoop() { log.Infof("Found %d shards", foundShards) } - // Count the number of leases hold by this worker excluding the processed shard + // Count the number of leases held by this worker excluding the processed shard counter := 0 for _, shard := range w.shardStatus { if shard.GetLeaseOwner() == w.workerID && shard.GetCheckpoint() != chk.ShardEnd { @@ -302,6 +303,20 @@ func (w *Worker) eventLoop() { continue } + var stealShard bool + if w.kclConfig.EnableLeaseStealing && shard.ClaimRequest != "" { + upcomingStealingInterval := time.Now().UTC().Add(time.Duration(w.kclConfig.LeaseStealingIntervalMillis) * time.Millisecond) + if shard.GetLeaseTimeout().Before(upcomingStealingInterval) && !shard.IsClaimRequestExpired(w.kclConfig) { + if shard.ClaimRequest == w.workerID { + stealShard = true + log.Debugf("Stealing shard: %s", shard.ID) + } else { + log.Debugf("Shard being stolen: %s", shard.ID) + continue + } + } + } + err = w.checkpointer.GetLease(shard, w.workerID) if err != nil { // cannot get lease on the shard @@ -311,6 +326,11 @@ func (w *Worker) eventLoop() { continue } + if stealShard { + log.Debugf("Successfully stole shard: %+v", shard.ID) + w.shardStealInProgress = false + } + // log metrics on got lease w.mService.LeaseGained(shard.ID) w.waitGroup.Add(1) @@ -325,6 +345,13 @@ func (w *Worker) eventLoop() { } } + if w.kclConfig.EnableLeaseStealing { + err = w.rebalance() + if err != nil { + log.Warnf("Error in rebalance: %+v", err) + } + } + select { case <-*w.stop: log.Infof("Shutting down...") @@ -335,6 +362,90 @@ func (w *Worker) eventLoop() { } } +func (w *Worker) rebalance() error { + log := w.kclConfig.Logger + + workers, err := w.checkpointer.ListActiveWorkers(w.shardStatus) + if err != nil { + log.Debugf("Error listing workers. workerID: %s. 
Error: %+v ", w.workerID, err) + return err + } + + // Only attempt to steal one shard at a time, to allow for linear convergence + if w.shardStealInProgress { + shardInfo := make(map[string]bool) + err := w.getShardIDs("", shardInfo) + if err != nil { + return err + } + for _, shard := range w.shardStatus { + if shard.ClaimRequest != "" && shard.ClaimRequest == w.workerID { + log.Debugf("Steal in progress. workerID: %s", w.workerID) + return nil + } + // Our shard steal was stomped on by a Checkpoint. + // We could deal with that, but instead just try again + w.shardStealInProgress = false + } + } + + var numShards int + for _, shards := range workers { + numShards += len(shards) + } + + numWorkers := len(workers) + + // 1:1 shards to workers is optimal, so we cannot possibly rebalance + if numWorkers >= numShards { + log.Debugf("Optimal shard allocation, not stealing any shards. workerID: %s, %v > %v. ", w.workerID, numWorkers, numShards) + return nil + } + + currentShards, ok := workers[w.workerID] + var numCurrentShards int + if !ok { + numCurrentShards = 0 + numWorkers++ + } else { + numCurrentShards = len(currentShards) + } + + optimalShards := numShards / numWorkers + + // We have more than or equal optimal shards, so no rebalancing can take place + if numCurrentShards >= optimalShards || numCurrentShards == w.kclConfig.MaxLeasesForWorker { + log.Debugf("We have enough shards, not attempting to steal any. workerID: %s", w.workerID) + return nil + } + maxShards := int(optimalShards) + var workerSteal string + for worker, shards := range workers { + if worker != w.workerID && len(shards) > maxShards { + workerSteal = worker + maxShards = len(shards) + } + } + // Not all shards are allocated so fallback to default shard allocation mechanisms + if workerSteal == "" { + log.Infof("Not all shards are allocated, not stealing any. 
workerID: %s", w.workerID) + return nil + } + + // Steal a random shard from the worker with the most shards + w.shardStealInProgress = true + randIndex := rand.Intn(len(workers[workerSteal])) + shardToSteal := workers[workerSteal][randIndex] + log.Debugf("Stealing shard %s from %s", shardToSteal, workerSteal) + + err = w.checkpointer.ClaimShard(w.shardStatus[shardToSteal.ID], w.workerID) + if err != nil { + w.shardStealInProgress = false + return err + } + return nil +} + // List all shards and store them into shardStatus table // If shard has been removed, need to exclude it from cached shard status. func (w *Worker) getShardIDs(nextToken string, shardInfo map[string]bool) error { diff --git a/support/toolchain/docker/Dockerfile b/support/toolchain/docker/Dockerfile index 1e66efe..47a4528 100644 --- a/support/toolchain/docker/Dockerfile +++ b/support/toolchain/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.12 +FROM golang:1.13 ENV PATH /go/bin:/src/bin:/root/go/bin:/usr/local/go/bin:$PATH ENV GOPATH /go:/src RUN go get -v github.com/alecthomas/gometalinter && \ diff --git a/test/lease_stealing_util_test.go b/test/lease_stealing_util_test.go new file mode 100644 index 0000000..21b8ab3 --- /dev/null +++ b/test/lease_stealing_util_test.go @@ -0,0 +1,230 @@ +package test + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/stretchr/testify/assert" + chk "github.com/vmware/vmware-go-kcl/clientlibrary/checkpoint" + cfg "github.com/vmware/vmware-go-kcl/clientlibrary/config" + wk "github.com/vmware/vmware-go-kcl/clientlibrary/worker" +) + +type LeaseStealingTest struct { + t *testing.T + config *TestClusterConfig + cluster *TestCluster + kc kinesisiface.KinesisAPI + dc dynamodbiface.DynamoDBAPI + + backOffSeconds int + maxRetries int +} + +func 
NewLeaseStealingTest(t *testing.T, config *TestClusterConfig, workerFactory TestWorkerFactory) *LeaseStealingTest { + cluster := NewTestCluster(t, config, workerFactory) + clientConfig := cluster.workerFactory.CreateKCLConfig("test-client", config) + return &LeaseStealingTest{ + t: t, + config: config, + cluster: cluster, + kc: NewKinesisClient(t, config.regionName, clientConfig.KinesisEndpoint, clientConfig.KinesisCredentials), + dc: NewDynamoDBClient(t, config.regionName, clientConfig.DynamoDBEndpoint, clientConfig.KinesisCredentials), + backOffSeconds: 5, + maxRetries: 60, + } +} + +func (lst *LeaseStealingTest) WithBackoffSeconds(backoff int) *LeaseStealingTest { + lst.backOffSeconds = backoff + return lst +} + +func (lst *LeaseStealingTest) WithMaxRetries(retries int) *LeaseStealingTest { + lst.maxRetries = retries + return lst +} + +func (lst *LeaseStealingTest) publishSomeData() (stop func()) { + done := make(chan int) + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + ticker := time.NewTicker(500 * time.Millisecond) + defer wg.Done() + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + lst.t.Log("Continuously publishing records") + publishSomeData(lst.t, lst.kc) + } + } + }() + + return func() { + close(done) + wg.Wait() + } +} + +func (lst *LeaseStealingTest) getShardCountByWorker() map[string]int { + input := &dynamodb.ScanInput{ + TableName: aws.String(lst.config.appName), + } + + shardsByWorker := map[string]map[string]bool{} + err := lst.dc.ScanPages(input, func(out *dynamodb.ScanOutput, lastPage bool) bool { + for _, result := range out.Items { + if shardID, ok := result[chk.LeaseKeyKey]; !ok { + continue + } else if assignedTo, ok := result[chk.LeaseOwnerKey]; !ok { + continue + } else { + if _, ok := shardsByWorker[*assignedTo.S]; !ok { + shardsByWorker[*assignedTo.S] = map[string]bool{} + } + shardsByWorker[*assignedTo.S][*shardID.S] = true + } + } + return !lastPage + }) + assert.Nil(lst.t, err) + + 
shardCountByWorker := map[string]int{} + for worker, shards := range shardsByWorker { + shardCountByWorker[worker] = len(shards) + } + return shardCountByWorker +} + +type LeaseStealingAssertions struct { + expectedLeasesForIntialWorker int + expectedLeasesPerWorker int +} + +func (lst *LeaseStealingTest) Run(assertions LeaseStealingAssertions) { + // Publish records onto stream throughout the entire duration of the test + stop := lst.publishSomeData() + defer stop() + + // Start worker 1 + worker1, _ := lst.cluster.SpawnWorker() + + // Wait until the above worker has all leases + var worker1ShardCount int + for i := 0; i < lst.maxRetries; i++ { + time.Sleep(time.Duration(lst.backOffSeconds) * time.Second) + + shardCountByWorker := lst.getShardCountByWorker() + if shardCount, ok := shardCountByWorker[worker1]; ok && shardCount == assertions.expectedLeasesForIntialWorker { + worker1ShardCount = shardCount + break + } + } + + // Assert correct number of leases + assert.Equal(lst.t, assertions.expectedLeasesForIntialWorker, worker1ShardCount) + + // Spawn Remaining Workers + for i := 0; i < lst.config.numWorkers-1; i++ { + lst.cluster.SpawnWorker() + } + + // Wait For Rebalance + var shardCountByWorker map[string]int + for i := 0; i < lst.maxRetries; i++ { + time.Sleep(time.Duration(lst.backOffSeconds) * time.Second) + + shardCountByWorker = lst.getShardCountByWorker() + + correctCount := true + for _, count := range shardCountByWorker { + if count != assertions.expectedLeasesPerWorker { + correctCount = false + } + } + + if correctCount { + break + } + } + + // Assert Rebalanced + assert.Greater(lst.t, len(shardCountByWorker), 0) + for _, count := range shardCountByWorker { + assert.Equal(lst.t, assertions.expectedLeasesPerWorker, count) + } + + // Shutdown Workers + time.Sleep(10 * time.Second) + lst.cluster.Shutdown() +} + +type TestWorkerFactory interface { + CreateWorker(workerID string, kclConfig *cfg.KinesisClientLibConfiguration) *wk.Worker + 
CreateKCLConfig(workerID string, config *TestClusterConfig) *cfg.KinesisClientLibConfiguration +} + +type TestClusterConfig struct { + numShards int + numWorkers int + + appName string + streamName string + regionName string + workerIDTemplate string +} + +type TestCluster struct { + t *testing.T + config *TestClusterConfig + workerFactory TestWorkerFactory + workerIDs []string + workers map[string]*wk.Worker +} + +func NewTestCluster(t *testing.T, config *TestClusterConfig, workerFactory TestWorkerFactory) *TestCluster { + return &TestCluster{ + t: t, + config: config, + workerFactory: workerFactory, + workerIDs: make([]string, 0), + workers: make(map[string]*wk.Worker), + } +} + +func (tc *TestCluster) addWorker(workerID string, config *cfg.KinesisClientLibConfiguration) *wk.Worker { + worker := tc.workerFactory.CreateWorker(workerID, config) + tc.workerIDs = append(tc.workerIDs, workerID) + tc.workers[workerID] = worker + return worker +} + +func (tc *TestCluster) SpawnWorker() (string, *wk.Worker) { + id := len(tc.workers) + workerID := fmt.Sprintf(tc.config.workerIDTemplate, id) + + config := tc.workerFactory.CreateKCLConfig(workerID, tc.config) + worker := tc.addWorker(workerID, config) + + err := worker.Start() + assert.Nil(tc.t, err) + return workerID, worker +} + +func (tc *TestCluster) Shutdown() { + for workerID, worker := range tc.workers { + tc.t.Logf("Shutting down worker: %v", workerID) + worker.Shutdown() + } +} diff --git a/test/logger_test.go b/test/logger_test.go index 2d63124..f5db877 100644 --- a/test/logger_test.go +++ b/test/logger_test.go @@ -23,9 +23,10 @@ package test import ( "github.com/stretchr/testify/assert" + "testing" + "github.com/sirupsen/logrus" "go.uber.org/zap" - "testing" "github.com/vmware/vmware-go-kcl/logger" zaplogger "github.com/vmware/vmware-go-kcl/logger/zap" diff --git a/test/record_processor_test.go b/test/record_processor_test.go index 31a8556..4f36266 100644 --- a/test/record_processor_test.go +++ 
b/test/record_processor_test.go @@ -19,10 +19,11 @@ package test import ( + "testing" + "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" kc "github.com/vmware/vmware-go-kcl/clientlibrary/interfaces" - "testing" ) // Record processor factory is used to create RecordProcessor diff --git a/test/record_publisher_test.go b/test/record_publisher_test.go index f948fc1..baaac57 100644 --- a/test/record_publisher_test.go +++ b/test/record_publisher_test.go @@ -21,9 +21,13 @@ package test import ( "crypto/md5" "fmt" + "sync" + "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" rec "github.com/awslabs/kinesis-aggregation/go/records" @@ -50,12 +54,79 @@ func NewKinesisClient(t *testing.T, regionName, endpoint string, credentials *cr return kinesis.New(s) } +// NewDynamoDBClient to create a DynamoDB client. +func NewDynamoDBClient(t *testing.T, regionName, endpoint string, credentials *credentials.Credentials) *dynamodb.DynamoDB { + s, err := session.NewSession(&aws.Config{ + Region: aws.String(regionName), + Endpoint: aws.String(endpoint), + Credentials: credentials, + }) + + if err != nil { + // no need to move forward + t.Fatalf("Failed in getting DynamoDB session for creating Worker: %+v", err) + } + return dynamodb.New(s) +} + +func continuouslyPublishSomeData(t *testing.T, kc kinesisiface.KinesisAPI) func() { + shards := []*kinesis.Shard{} + var nextToken *string + for { + out, err := kc.ListShards(&kinesis.ListShardsInput{ + StreamName: aws.String(streamName), + NextToken: nextToken, + }) + if err != nil { + t.Errorf("Error in ListShards. %+v", err) + } + + shards = append(shards, out.Shards...) 
+ if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + + done := make(chan int) + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + ticker := time.NewTicker(500 * time.Millisecond) + for { + select { + case <-done: + return + case <-ticker.C: + publishToAllShards(t, kc, shards) + publishSomeData(t, kc) + } + } + }() + + return func() { + close(done) + wg.Wait() + } +} + +func publishToAllShards(t *testing.T, kc kinesisiface.KinesisAPI, shards []*kinesis.Shard) { + // Put records to all shards + for i := 0; i < 10; i++ { + for _, shard := range shards { + publishRecord(t, kc, shard.HashKeyRange.StartingHashKey) + } + } +} + // publishSomeData to put some records into Kinesis stream func publishSomeData(t *testing.T, kc kinesisiface.KinesisAPI) { // Put some data into stream. t.Log("Putting data into stream using PutRecord API...") for i := 0; i < 50; i++ { - publishRecord(t, kc) + publishRecord(t, kc, nil) } t.Log("Done putting data into stream using PutRecord API.") @@ -75,13 +146,17 @@ func publishSomeData(t *testing.T, kc kinesisiface.KinesisAPI) { } // publishRecord to put a record into Kinesis stream using PutRecord API. -func publishRecord(t *testing.T, kc kinesisiface.KinesisAPI) { - // Use random string as partition key to ensure even distribution across shards - _, err := kc.PutRecord(&kinesis.PutRecordInput{ +func publishRecord(t *testing.T, kc kinesisiface.KinesisAPI, hashKey *string) { + input := &kinesis.PutRecordInput{ Data: []byte(specstr), StreamName: aws.String(streamName), PartitionKey: aws.String(utils.RandStringBytesMaskImpr(10)), - }) + } + if hashKey != nil { + input.ExplicitHashKey = hashKey + } + // Use random string as partition key to ensure even distribution across shards + _, err := kc.PutRecord(input) if err != nil { t.Errorf("Error in PutRecord. 
%+v", err) @@ -94,10 +169,11 @@ func publishRecords(t *testing.T, kc kinesisiface.KinesisAPI) { records := make([]*kinesis.PutRecordsRequestEntry, 5) for i := 0; i < 5; i++ { - records[i] = &kinesis.PutRecordsRequestEntry{ + record := &kinesis.PutRecordsRequestEntry{ Data: []byte(specstr), PartitionKey: aws.String(utils.RandStringBytesMaskImpr(10)), } + records[i] = record } _, err := kc.PutRecords(&kinesis.PutRecordsInput{ diff --git a/test/worker_custom_test.go b/test/worker_custom_test.go index ef48491..19a6fb7 100644 --- a/test/worker_custom_test.go +++ b/test/worker_custom_test.go @@ -37,7 +37,7 @@ import ( ) func TestWorkerInjectCheckpointer(t *testing.T) { - kclConfig := cfg.NewKinesisClientLibConfig("appName", streamName, regionName, workerID). + kclConfig := cfg.NewKinesisClientLibConfig(appName, streamName, regionName, workerID). WithInitialPositionInStream(cfg.LATEST). WithMaxRecords(10). WithMaxLeasesForWorker(1). @@ -52,6 +52,12 @@ func TestWorkerInjectCheckpointer(t *testing.T) { // configure cloudwatch as metrics system kclConfig.WithMonitoringService(getMetricsConfig(kclConfig, metricsSystem)) + // Put some data into stream. + kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials) + // publishSomeData(t, kc) + stop := continuouslyPublishSomeData(t, kc) + defer stop() + // custom checkpointer or a mock checkpointer. checkpointer := chk.NewDynamoCheckpoint(kclConfig) @@ -62,12 +68,8 @@ func TestWorkerInjectCheckpointer(t *testing.T) { err := worker.Start() assert.Nil(t, err) - // Put some data into stream. 
- kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials) - publishSomeData(t, kc) - // wait a few seconds before shutdown processing - time.Sleep(10 * time.Second) + time.Sleep(30 * time.Second) worker.Shutdown() // verify the checkpointer after graceful shutdown @@ -86,7 +88,7 @@ func TestWorkerInjectCheckpointer(t *testing.T) { } func TestWorkerInjectKinesis(t *testing.T) { - kclConfig := cfg.NewKinesisClientLibConfig("appName", streamName, regionName, workerID). + kclConfig := cfg.NewKinesisClientLibConfig(appName, streamName, regionName, workerID). WithInitialPositionInStream(cfg.LATEST). WithMaxRecords(10). WithMaxLeasesForWorker(1). @@ -109,6 +111,11 @@ func TestWorkerInjectKinesis(t *testing.T) { assert.Nil(t, err) kc := kinesis.New(s) + // Put some data into stream. + // publishSomeData(t, kc) + stop := continuouslyPublishSomeData(t, kc) + defer stop() + // Inject a custom checkpointer into the worker. worker := wk.NewWorker(recordProcessorFactory(t), kclConfig). WithKinesis(kc) @@ -116,16 +123,13 @@ func TestWorkerInjectKinesis(t *testing.T) { err = worker.Start() assert.Nil(t, err) - // Put some data into stream. - publishSomeData(t, kc) - // wait a few seconds before shutdown processing - time.Sleep(10 * time.Second) + time.Sleep(30 * time.Second) worker.Shutdown() } func TestWorkerInjectKinesisAndCheckpointer(t *testing.T) { - kclConfig := cfg.NewKinesisClientLibConfig("appName", streamName, regionName, workerID). + kclConfig := cfg.NewKinesisClientLibConfig(appName, streamName, regionName, workerID). WithInitialPositionInStream(cfg.LATEST). WithMaxRecords(10). WithMaxLeasesForWorker(1). @@ -148,6 +152,11 @@ func TestWorkerInjectKinesisAndCheckpointer(t *testing.T) { assert.Nil(t, err) kc := kinesis.New(s) + // Put some data into stream. + // publishSomeData(t, kc) + stop := continuouslyPublishSomeData(t, kc) + defer stop() + // custom checkpointer or a mock checkpointer. 
checkpointer := chk.NewDynamoCheckpoint(kclConfig) @@ -159,10 +168,7 @@ func TestWorkerInjectKinesisAndCheckpointer(t *testing.T) { err = worker.Start() assert.Nil(t, err) - // Put some data into stream. - publishSomeData(t, kc) - // wait a few seconds before shutdown processing - time.Sleep(10 * time.Second) + time.Sleep(30 * time.Second) worker.Shutdown() } diff --git a/test/worker_lease_stealing_test.go b/test/worker_lease_stealing_test.go new file mode 100644 index 0000000..c35974c --- /dev/null +++ b/test/worker_lease_stealing_test.go @@ -0,0 +1,127 @@ +package test + +import ( + "testing" + + chk "github.com/vmware/vmware-go-kcl/clientlibrary/checkpoint" + cfg "github.com/vmware/vmware-go-kcl/clientlibrary/config" + wk "github.com/vmware/vmware-go-kcl/clientlibrary/worker" + "github.com/vmware/vmware-go-kcl/logger" +) + +func TestLeaseStealing(t *testing.T) { + config := &TestClusterConfig{ + numShards: 4, + numWorkers: 2, + appName: appName, + streamName: streamName, + regionName: regionName, + workerIDTemplate: workerID + "-%v", + } + test := NewLeaseStealingTest(t, config, newLeaseStealingWorkerFactory(t)) + test.Run(LeaseStealingAssertions{ + expectedLeasesForIntialWorker: config.numShards, + expectedLeasesPerWorker: config.numShards / config.numWorkers, + }) +} + +type leaseStealingWorkerFactory struct { + t *testing.T +} + +func newLeaseStealingWorkerFactory(t *testing.T) *leaseStealingWorkerFactory { + return &leaseStealingWorkerFactory{t} +} + +func (wf *leaseStealingWorkerFactory) CreateKCLConfig(workerID string, config *TestClusterConfig) *cfg.KinesisClientLibConfiguration { + log := logger.NewLogrusLoggerWithConfig(logger.Configuration{ + EnableConsole: true, + ConsoleLevel: logger.Error, + ConsoleJSONFormat: false, + EnableFile: true, + FileLevel: logger.Info, + FileJSONFormat: true, + Filename: "log.log", + }) + + log.WithFields(logger.Fields{"worker": workerID}) + + return cfg.NewKinesisClientLibConfig(config.appName, config.streamName, 
config.regionName, workerID). + WithInitialPositionInStream(cfg.LATEST). + WithMaxRecords(10). + WithShardSyncIntervalMillis(5000). + WithFailoverTimeMillis(10000). + WithLeaseStealing(true). + WithLogger(log) +} + +func (wf *leaseStealingWorkerFactory) CreateWorker(workerID string, kclConfig *cfg.KinesisClientLibConfiguration) *wk.Worker { + worker := wk.NewWorker(recordProcessorFactory(wf.t), kclConfig) + return worker +} + +func TestLeaseStealingInjectCheckpointer(t *testing.T) { + config := &TestClusterConfig{ + numShards: 4, + numWorkers: 2, + appName: appName, + streamName: streamName, + regionName: regionName, + workerIDTemplate: workerID + "-%v", + } + test := NewLeaseStealingTest(t, config, newleaseStealingWorkerFactoryCustomChk(t)) + test.Run(LeaseStealingAssertions{ + expectedLeasesForIntialWorker: config.numShards, + expectedLeasesPerWorker: config.numShards / config.numWorkers, + }) +} + +type leaseStealingWorkerFactoryCustom struct { + *leaseStealingWorkerFactory +} + +func newleaseStealingWorkerFactoryCustomChk(t *testing.T) *leaseStealingWorkerFactoryCustom { + return &leaseStealingWorkerFactoryCustom{ + newLeaseStealingWorkerFactory(t), + } +} + +func (wfc *leaseStealingWorkerFactoryCustom) CreateWorker(workerID string, kclConfig *cfg.KinesisClientLibConfiguration) *wk.Worker { + worker := wfc.leaseStealingWorkerFactory.CreateWorker(workerID, kclConfig) + checkpointer := chk.NewDynamoCheckpoint(kclConfig) + return worker.WithCheckpointer(checkpointer) +} + +func TestLeaseStealingWithMaxLeasesForWorker(t *testing.T) { + config := &TestClusterConfig{ + numShards: 4, + numWorkers: 2, + appName: appName, + streamName: streamName, + regionName: regionName, + workerIDTemplate: workerID + "-%v", + } + test := NewLeaseStealingTest(t, config, newleaseStealingWorkerFactoryMaxLeases(t, config.numShards-1)) + test.Run(LeaseStealingAssertions{ + expectedLeasesForIntialWorker: config.numShards - 1, + expectedLeasesPerWorker: 2, + }) +} + +type 
leaseStealingWorkerFactoryMaxLeases struct { + maxLeases int + *leaseStealingWorkerFactory +} + +func newleaseStealingWorkerFactoryMaxLeases(t *testing.T, maxLeases int) *leaseStealingWorkerFactoryMaxLeases { + return &leaseStealingWorkerFactoryMaxLeases{ + maxLeases, + newLeaseStealingWorkerFactory(t), + } +} + +func (wfm *leaseStealingWorkerFactoryMaxLeases) CreateKCLConfig(workerID string, config *TestClusterConfig) *cfg.KinesisClientLibConfiguration { + kclConfig := wfm.leaseStealingWorkerFactory.CreateKCLConfig(workerID, config) + kclConfig.WithMaxLeasesForWorker(wfm.maxLeases) + return kclConfig +} diff --git a/test/worker_test.go b/test/worker_test.go index b9f9a32..a445a59 100644 --- a/test/worker_test.go +++ b/test/worker_test.go @@ -60,7 +60,7 @@ func TestWorker(t *testing.T) { // In order to have precise control over logging. Use logger with config config := logger.Configuration{ EnableConsole: true, - ConsoleLevel: logger.Debug, + ConsoleLevel: logger.Error, ConsoleJSONFormat: false, EnableFile: true, FileLevel: logger.Info, @@ -269,8 +269,13 @@ func runTest(kclConfig *cfg.KinesisClientLibConfiguration, triggersig bool, t *t // configure cloudwatch as metrics system kclConfig.WithMonitoringService(getMetricsConfig(kclConfig, metricsSystem)) - worker := wk.NewWorker(recordProcessorFactory(t), kclConfig) + // Put some data into stream. + kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials) + // publishSomeData(t, kc) + stop := continuouslyPublishSomeData(t, kc) + defer stop() + worker := wk.NewWorker(recordProcessorFactory(t), kclConfig) err := worker.Start() assert.Nil(t, err) @@ -286,10 +291,6 @@ func runTest(kclConfig *cfg.KinesisClientLibConfiguration, triggersig bool, t *t //os.Exit(0) }() - // Put some data into stream. 
- kc := NewKinesisClient(t, regionName, kclConfig.KinesisEndpoint, kclConfig.KinesisCredentials) - publishSomeData(t, kc) - if triggersig { t.Log("Trigger signal SIGINT") p, _ := os.FindProcess(os.Getpid()) @@ -297,7 +298,7 @@ func runTest(kclConfig *cfg.KinesisClientLibConfiguration, triggersig bool, t *t } // wait a few seconds before shutdown processing - time.Sleep(10 * time.Second) + time.Sleep(30 * time.Second) if metricsSystem == "prometheus" { res, err := http.Get("http://localhost:8080/metrics")