Skip to content

Commit

Permalink
[release-20.0-rc] fix: rows affected count for multi table update for…
Browse files Browse the repository at this point in the history
… non-literal column value (#16181) (#16182)

Signed-off-by: Harshit Gangal <harshit@planetscale.com>
Co-authored-by: vitess-bot[bot] <108069721+vitess-bot[bot]@users.noreply.github.com>
  • Loading branch information
vitess-bot[bot] authored Jun 14, 2024
1 parent c3c77e1 commit 1f308de
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 2 deletions.
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[
if res == nil {
res = qr
} else {
res.RowsAffected += res.RowsAffected
res.RowsAffected += qr.RowsAffected
}
}
return res, nil
Expand Down
76 changes: 76 additions & 0 deletions go/vt/vtgate/engine/dml_with_input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -180,3 +181,78 @@ func TestDeleteWithMultiTarget(t *testing.T) {
`ExecuteMultiShard ks.-20: dummy_delete_2 {dml_vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x03100\x89\x02\x011"} values:{type:TUPLE value:"\x89\x02\x03100\x89\x02\x012"} values:{type:TUPLE value:"\x89\x02\x03200\x89\x02\x013"}} true true`,
})
}

// TestUpdateWithInputNonLiteral test the case where the column updated have non literal update.
// Therefore, update query should be executed for each row in the input result.
// This also validates the output rows affected.
func TestUpdateWithInputNonLiteral(t *testing.T) {
input := &fakePrimitive{results: []*sqltypes.Result{
sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|val", "int64|varchar|int64"), "1|a|100", "2|b|200", "3|c|300"),
}}

dml := &DMLWithInput{
Input: input,
DMLs: []Primitive{&Update{
DML: &DML{
RoutingParameters: &RoutingParameters{
Opcode: Scatter,
Keyspace: &vindexes.Keyspace{
Name: "ks",
Sharded: true,
},
},
Query: "dummy_update",
},
}},
OutputCols: [][]int{{1, 0}},
BVList: []map[string]int{
{"bv1": 2},
},
}

vc := newDMLTestVCursor("-20", "20-")
vc.results = []*sqltypes.Result{
{RowsAffected: 1}, {RowsAffected: 1}, {RowsAffected: 1},
}
qr, err := dml.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false)
require.NoError(t, err)
vc.ExpectLog(t, []string{
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ` +
`ks.-20: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} ` +
`ks.20-: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} true false`,
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ` +
`ks.-20: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} ` +
`ks.20-: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} true false`,
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ` +
`ks.-20: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} ` +
`ks.20-: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} true false`,
})
assert.EqualValues(t, 3, qr.RowsAffected)

vc.Rewind()
input.rewind()
err = dml.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false,
func(result *sqltypes.Result) error {
qr = result
return nil
})
require.NoError(t, err)
vc.ExpectLog(t, []string{
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ` +
`ks.-20: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} ` +
`ks.20-: dummy_update {bv1: type:INT64 value:"100" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01a\x89\x02\x011"}} true false`,
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ` +
`ks.-20: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} ` +
`ks.20-: dummy_update {bv1: type:INT64 value:"200" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01b\x89\x02\x012"}} true false`,
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ` +
`ks.-20: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} ` +
`ks.20-: dummy_update {bv1: type:INT64 value:"300" dml_vals: type:TUPLE values:{type:TUPLE value:"\x950\x01c\x89\x02\x013"}} true false`,
})
assert.EqualValues(t, 3, qr.RowsAffected)
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ func (f *loggingVCursor) nextResult() (*sqltypes.Result, error) {
if r == nil {
return &sqltypes.Result{}, f.resultErr
}
return r, nil
return r.Copy(), nil
}

func (f *loggingVCursor) CanUseSetVar() bool {
Expand Down

0 comments on commit 1f308de

Please sign in to comment.