diff --git a/aws-v1/client/client.go b/aws-v1/client/client.go index c0ea1d4..11d1ebc 100644 --- a/aws-v1/client/client.go +++ b/aws-v1/client/client.go @@ -653,6 +653,73 @@ func (fd *Client) TransactWriteItemsWithContext(ctx aws.Context, input *dynamodb return fd.TransactWriteItems(input) } +// BatchGetItem mock response for dynamodb +func (fd *Client) BatchGetItem(input *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) { + if fd.forceFailureErr != nil { + return nil, fd.forceFailureErr + } + + responses := make(map[string][]map[string]*dynamodb.AttributeValue, len(input.RequestItems)) + unprocessed := make(map[string]*dynamodb.KeysAndAttributes, len(input.RequestItems)) + + for tableName, reqs := range input.RequestItems { + unprocessedKeys := make([]map[string]*dynamodb.AttributeValue, 0, len(reqs.Keys)) + responses[tableName] = make([]map[string]*dynamodb.AttributeValue, 0, len(reqs.Keys)) + + for _, req := range reqs.Keys { + getInput := &dynamodb.GetItemInput{ + TableName: aws.String(tableName), + Key: req, + ConsistentRead: reqs.ConsistentRead, + AttributesToGet: reqs.AttributesToGet, + ExpressionAttributeNames: reqs.ExpressionAttributeNames, + ProjectionExpression: reqs.ProjectionExpression, + } + + item, err := executeGetRequest(fd, getInput) + if err != nil { + unprocessedKeys = append(unprocessedKeys, req) + + continue + } + + responses[tableName] = append(responses[tableName], item) + } + + if len(unprocessedKeys) > 0 { + unprocessed[tableName] = reqs + + tableUnprocessedKeys := unprocessed[tableName] + tableUnprocessedKeys.Keys = unprocessedKeys + + unprocessed[tableName] = tableUnprocessedKeys + } + } + + return &dynamodb.BatchGetItemOutput{ + Responses: responses, + UnprocessedKeys: unprocessed, + }, nil +} + +// BatchGetItemWithContext mock response for dynamodb +func (fd *Client) BatchGetItemWithContext(ctx aws.Context, input *dynamodb.BatchGetItemInput, opts ...request.Option) (*dynamodb.BatchGetItemOutput, error) { + return fd.BatchGetItem(input) +} + +func executeGetRequest(fd *Client, getInput *dynamodb.GetItemInput) (map[string]*dynamodb.AttributeValue, error) { + response, err := fd.GetItem(getInput) + if err != nil { + return nil, err + } + + if len(response.Item) == 0 { + return nil, ErrResourceNotFoundException + } + + return response.Item, nil +} + func (fd *Client) getTable(tableName string) (*core.Table, error) { table, ok := fd.tables[tableName] if !ok { diff --git a/aws-v1/client/client_test.go b/aws-v1/client/client_test.go index 387d872..b01487a 100644 --- a/aws-v1/client/client_test.go +++ b/aws-v1/client/client_test.go @@ -1856,6 +1856,82 @@ func TestBatchWriteItemWithFailingDatabase(t *testing.T) { c.NotEmpty(output.UnprocessedItems) } +func TestBatchGetItemWithContext(t *testing.T) { + c := require.New(t) + client := setupClient(tableName) + + err := ensurePokemonTable(client) + c.NoError(err) + + item, err := dynamodbattribute.MarshalMap(pokemon{ + ID: "001", + Type: "grass", + Name: "Bulbasaur", + }) + c.NoError(err) + + input := &dynamodb.PutItemInput{ + Item: item, + TableName: aws.String(tableName), + } + + _, err = client.PutItemWithContext(context.Background(), input) + c.NoError(err) + + item, err = dynamodbattribute.MarshalMap(pokemon{ + ID: "002", + Type: "fire", + Name: "Charmander", + }) + c.NoError(err) + + input = &dynamodb.PutItemInput{ + Item: item, + TableName: aws.String(tableName), + } + + _, err = client.PutItemWithContext(context.Background(), input) + c.NoError(err) + + getInput := &dynamodb.BatchGetItemInput{ + RequestItems: map[string]*dynamodb.KeysAndAttributes{ + tableName: { + Keys: []map[string]*dynamodb.AttributeValue{ + { + "id": {S: aws.String("001")}, + }, + { + "id": {S: aws.String("002")}, + }, + { + "t1": {S: aws.String("003")}, + }, + { + "id": {S: aws.String("004")}, + }, + }, + }, + }, + } + + ActiveForceFailure(client) + + out, err := client.BatchGetItemWithContext(context.Background(), getInput) + c.Nil(out) + c.EqualError(err, ErrForcedFailure.Error()) + + DeactiveForceFailure(client) + + out, err = client.BatchGetItemWithContext(context.Background(), getInput) + c.NoError(err) + c.Len(out.Responses[tableName], 2) + c.Equal("001", aws.StringValue(out.Responses[tableName][0]["id"].S)) + c.Equal("002", aws.StringValue(out.Responses[tableName][1]["id"].S)) + c.Len(out.UnprocessedKeys[tableName].Keys, 2) + c.Equal("003", aws.StringValue(out.UnprocessedKeys[tableName].Keys[0]["t1"].S)) + c.Equal("004", aws.StringValue(out.UnprocessedKeys[tableName].Keys[1]["id"].S)) +} + func TestTransactWriteItemsWithContext(t *testing.T) { c := require.New(t) client := NewClient()