diff --git a/go/vt/vtgate/engine/dml_with_input.go b/go/vt/vtgate/engine/dml_with_input.go index 0974f753cef..e0eb3b03592 100644 --- a/go/vt/vtgate/engine/dml_with_input.go +++ b/go/vt/vtgate/engine/dml_with_input.go @@ -146,7 +146,7 @@ func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[ if res == nil { res = qr } else { - res.RowsAffected += res.RowsAffected + res.RowsAffected += qr.RowsAffected } } return res, nil diff --git a/go/vt/vtgate/engine/dml_with_input_test.go b/go/vt/vtgate/engine/dml_with_input_test.go index b41dc9e148c..6fcf2040dfc 100644 --- a/go/vt/vtgate/engine/dml_with_input_test.go +++ b/go/vt/vtgate/engine/dml_with_input_test.go @@ -20,6 +20,7 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -180,3 +181,78 @@ func TestDeleteWithMultiTarget(t *testing.T) { `ExecuteMultiShard ks.-20: dummy_delete_2 {dml_vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x03100\x89\x02\x011"} values:{type:TUPLE value:"\x89\x02\x03100\x89\x02\x012"} values:{type:TUPLE value:"\x89\x02\x03200\x89\x02\x013"}} true true`, }) } + +// TestUpdateWithInputNonLiteral test the case where the column updated have non literal update. +// Therefore, update query should be executed for each row in the input result. +// This also validates the output rows affected. +func TestUpdateWithInputNonLiteral(t *testing.T) { + input := &fakePrimitive{results: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|val", "int64|varchar|int64"), "1|a|100", "2|b|200", "3|c|300"), + }} + + dml := &DMLWithInput{ + Input: input, + DMLs: []Primitive{&Update{ + DML: &DML{ + RoutingParameters: &RoutingParameters{ + Opcode: Scatter, + Keyspace: &vindexes.Keyspace{ + Name: "ks", + Sharded: true, + }, + }, + Query: "dummy_update", + }, + }}, + OutputCols: [][]int{{1, 0}}, + BVList: []map[string]int{ + {"bv1": 2}, + }, + } + + vc := newDMLTestVCursor("-20", "20-") + vc.results = []*sqltypes.Result{ + {RowsAffected: 1}, {RowsAffected: 1}, {RowsAffected: 1}, + } + qr, err := dml.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + require.NoError(t, err) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ` + + `ks.-20: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} ` + + `ks.20-: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} true false`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ` + + `ks.-20: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} ` + + `ks.20-: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} true false`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ` + + `ks.-20: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} ` + + `ks.20-: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} true false`, + }) + assert.EqualValues(t, 3, qr.RowsAffected) + + vc.Rewind() + input.rewind() + err = dml.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false, + func(result *sqltypes.Result) error { + qr = result + return nil + }) + require.NoError(t, err) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ` + + `ks.-20: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} ` + + `ks.20-: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} true false`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ` + + `ks.-20: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} ` + + `ks.20-: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} true false`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ` + + `ks.-20: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} ` + + `ks.20-: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} true false`, + }) + assert.EqualValues(t, 3, qr.RowsAffected) +} diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 08a40c0d835..1ba0abfa2ef 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -803,7 +803,7 @@ func (f *loggingVCursor) nextResult() (*sqltypes.Result, error) { if r == nil { return &sqltypes.Result{}, f.resultErr } - return r, nil + return r.Copy(), nil } func (f *loggingVCursor) CanUseSetVar() bool {