diff --git a/tests/transact_items_test.go b/tests/transact_items_test.go index 4ddfc4b..7d1c46c 100644 --- a/tests/transact_items_test.go +++ b/tests/transact_items_test.go @@ -23,82 +23,94 @@ func TestTransactItems(t *testing.T) { table := prepareTable(t, endoint, "transcact_item_test") testCases := []struct { - title string - condition string - keys map[string]types.AttributeValue - opts []dynago.QueryOptions - source1 []Terminal - source2 []Terminal + title string + condition string + keys map[string]types.AttributeValue + opts []dynago.QueryOptions + //items to be added + newItems []Terminal + operations []types.TransactWriteItem + //items expected to exist in table after transaction operation expected []Terminal expectedErr error }{{ - title: "assign terminal", + title: "assign terminal - only add a terminal", condition: "pk = :pk", keys: map[string]types.AttributeValue{ ":pk": &types.AttributeValueMemberS{Value: "terminal1"}, }, - source1: []Terminal{ - { + newItems: []Terminal{}, + operations: []types.TransactWriteItem{ + table.WithPutItem("terminal1", "merchant2", Terminal{ Id: "1", Pk: "terminal1", Sk: "merchant1", - }, + }), }, - source2: []Terminal{ + expected: []Terminal{ { Id: "1", Pk: "terminal1", - Sk: "merchant2", + Sk: "merchant1", }, }, - expected: []Terminal{ - { + }, + { + title: "assign terminal - delete existing and update with new", + condition: "pk = :pk", + keys: map[string]types.AttributeValue{ + ":pk": &types.AttributeValueMemberS{Value: "terminal1"}, + }, + newItems: []Terminal{{ Id: "1", Pk: "terminal1", Sk: "merchant2", + }}, + operations: []types.TransactWriteItem{ + table.WithDeleteItem("terminal1", "merchant1"), + table.WithPutItem("terminal1", "merchant2", Terminal{ + Id: "1", + Pk: "terminal1", + Sk: "merchant2", + }), + }, + expected: []Terminal{ + { + Id: "1", + Pk: "terminal1", + Sk: "merchant2", + }, }, }, - }, } for _, tc := range testCases { t.Run(tc.title, func(t *testing.T) { t.Helper() ctx := context.TODO() // Create Item - if len(tc.source1) > 0 { - items := make([]*dynago.TransactPutItemsInput, 0, len(tc.source1)) - for _, item := range tc.source1 { + if len(tc.newItems) > 0 { + items := make([]*dynago.TransactPutItemsInput, 0, len(tc.newItems)) + for _, item := range tc.newItems { items = append(items, &dynago.TransactPutItemsInput{ dynago.StringValue(item.Pk), dynago.StringValue(item.Sk), item, }) } err := table.TransactPutItems(ctx, items) if err != nil { - t.Fatalf("prepare table failed; got %s", err) - } - } - // Update Item - items := make([]types.TransactWriteItem, 0, len(tc.source1)+len(tc.source2)) - if len(tc.source1) > 0 { - for _, item := range tc.source1 { - items = append(items, table.WithDeleteItem(ctx, item.Pk, - item.Sk)) + t.Fatalf("transaction put items failed; got %s", err) } } - if len(tc.source2) > 0 { - for _, item := range tc.source2 { - items = append(items, table.WithPutItem(ctx, item.Pk, - item.Sk, - item)) + //perform operations + if len(tc.operations) > 0 { + err := table.TransactItems(ctx, tc.operations) + if err != nil { + t.Fatalf("error occurred %s", err) } - } - err := table.TransactItems(ctx, items) - if err != nil { - t.Fatalf("error occurred %s", err) + } var out []Terminal - _, err = table.Query(ctx, tc.condition, tc.keys, &out) + _, err := table.Query(ctx, tc.condition, tc.keys, &out) if tc.expectedErr != nil { if err == nil { t.Fatalf("expected query to fail with %s", tc.expectedErr) diff --git a/transaction_items.go b/transaction_items.go index 307de24..79ded31 100644 --- a/transaction_items.go +++ b/transaction_items.go @@ -9,7 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) -func (t *Client) WithDeleteItem(ctx context.Context, pk string, sk string) types.TransactWriteItem { +func (t *Client) WithDeleteItem(pk string, sk string) types.TransactWriteItem { return types.TransactWriteItem{ Delete: &types.Delete{ TableName: &t.TableName, @@ -22,7 +22,7 @@ func (t *Client) WithDeleteItem(ctx context.Context, pk string, sk string) types } -func (t *Client) WithPutItem(ctx context.Context, pk string, sk string, item interface{}) types.TransactWriteItem { +func (t *Client) WithPutItem(pk string, sk string, item interface{}) types.TransactWriteItem { av, err := attributevalue.MarshalMap(item) if err != nil { log.Println("Failed to Marshal item" + err.Error())