Skip to content

Commit

Permalink
fix executor and improve tests for scatter_con
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Dec 20, 2024
1 parent 66bc6eb commit aebd8e4
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 63 deletions.
16 changes: 5 additions & 11 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,24 +262,18 @@ func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConn
}

type streaminResultReceiver struct {
mu sync.Mutex
stmtType sqlparser.StatementType
rowsAffected uint64
rowsReturned int
insertID uint64
insertIDChanged bool
callback func(*sqltypes.Result) error
mu sync.Mutex
stmtType sqlparser.StatementType
rowsAffected uint64
rowsReturned int
callback func(*sqltypes.Result) error
}

func (s *streaminResultReceiver) storeResultStats(typ sqlparser.StatementType, qr *sqltypes.Result) error {
s.mu.Lock()
defer s.mu.Unlock()
s.rowsAffected += qr.RowsAffected
s.rowsReturned += len(qr.Rows)
if qr.InsertIDUpdated() {
s.insertID = qr.InsertID
}
s.insertIDChanged = s.insertIDChanged || qr.InsertIDUpdated()
s.stmtType = typ
return s.callback(qr)
}
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/scatter_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,11 @@ func (stc *ScatterConn) StreamExecuteMulti(
}
return callback(reply)
}

if session.Options != nil {
session.Options.FetchLastInsertId = fetchLastInsertID
}

allErrors := stc.multiGoTransaction(
ctx,
"StreamExecute",
Expand Down
160 changes: 108 additions & 52 deletions go/vt/vtgate/scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,23 @@ import (
"fmt"
"testing"

"vitess.io/vitess/go/vt/log"
econtext "vitess.io/vitess/go/vt/vtgate/executorcontext"

"vitess.io/vitess/go/mysql/sqlerror"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"

"github.com/aws/smithy-go/ptr"
"github.com/stretchr/testify/assert"

"vitess.io/vitess/go/vt/key"

"vitess.io/vitess/go/test/utils"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/sqlerror"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"
"vitess.io/vitess/go/vt/discovery"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/srvtopo"
"vitess.io/vitess/go/vt/vterrors"
econtext "vitess.io/vitess/go/vt/vtgate/executorcontext"
)

// This file uses the sandbox_test framework.
Expand Down Expand Up @@ -110,55 +106,115 @@ func TestExecuteFailOnAutocommit(t *testing.T) {
}

func TestFetchLastInsertIDResets(t *testing.T) {
ctx := utils.LeakCheckContext(t)

// This test verifies that the FetchLastInsertID flag is reset after a call to ExecuteMultiShard.
ks := "TestFetchLastInsertIDResets"
createSandbox(ks)
hc := discovery.NewFakeHealthCheck(nil)
sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc0 := hc.AddTestTablet("aa", "0", 1, ks, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
sbc1 := hc.AddTestTablet("aa", "1", 1, ks, "1", topodatapb.TabletType_PRIMARY, true, 1, nil)

rss := []*srvtopo.ResolvedShard{{
Target: &querypb.Target{
Keyspace: ks,
Shard: "0",
TabletType: topodatapb.TabletType_PRIMARY,
tests := []struct {
name string
initialSessionOpts *querypb.ExecuteOptions
fetchLastInsertID bool
expectSessionNil bool
expectFetchLastID *bool // nil means checkLastOptionNil, otherwise checkLastOption(*bool)
}{
{
name: "no session options, fetchLastInsertID = false",
initialSessionOpts: nil,
fetchLastInsertID: false,
expectSessionNil: true,
expectFetchLastID: nil,
},
Gateway: sbc0,
}, {
Target: &querypb.Target{
Keyspace: ks,
Shard: "1",
TabletType: topodatapb.TabletType_PRIMARY,
{
name: "no session options, fetchLastInsertID = true",
initialSessionOpts: nil,
fetchLastInsertID: true,
expectSessionNil: true,

expectFetchLastID: ptr.Bool(true),
},
Gateway: sbc1,
}}
queries := []*querypb.BoundQuery{{
// This will fail to go to shard. It will be rejected at vtgate.
Sql: "query1",
BindVariables: map[string]*querypb.BindVariable{
"bv0": sqltypes.Int64BindVariable(0),
{
name: "session options set, fetchLastInsertID = false",
initialSessionOpts: &querypb.ExecuteOptions{},
fetchLastInsertID: false,
expectSessionNil: false,
expectFetchLastID: ptr.Bool(false),
},
}, {
// This will go to shard.
Sql: "query2",
BindVariables: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(1),
{
name: "session options set, fetchLastInsertID = true",
initialSessionOpts: &querypb.ExecuteOptions{},
fetchLastInsertID: true,
expectSessionNil: false,
expectFetchLastID: ptr.Bool(true),
},
}}
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := utils.LeakCheckContext(t)

createSandbox(ks)
hc := discovery.NewFakeHealthCheck(nil)
sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc0 := hc.AddTestTablet("aa", "0", 1, ks, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
sbc1 := hc.AddTestTablet("aa", "1", 1, ks, "1", topodatapb.TabletType_PRIMARY, true, 1, nil)

rss := []*srvtopo.ResolvedShard{{
Target: &querypb.Target{
Keyspace: ks,
Shard: "0",
TabletType: topodatapb.TabletType_PRIMARY,
},
Gateway: sbc0,
}, {
Target: &querypb.Target{
Keyspace: ks,
Shard: "1",
TabletType: topodatapb.TabletType_PRIMARY,
},
Gateway: sbc1,
}}
queries := []*querypb.BoundQuery{{
Sql: "query1",
BindVariables: map[string]*querypb.BindVariable{
"bv0": sqltypes.Int64BindVariable(0),
},
}, {
Sql: "query2",
BindVariables: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(1),
},
}}

session := econtext.NewSafeSession(nil)
session.Options = tt.initialSessionOpts

checkLastOption := func(expected bool) {
require.Equal(t, 1, len(sbc0.Options))
options := sbc0.Options[0]
assert.Equal(t, options.FetchLastInsertId, expected)
sbc0.Options = nil
}
checkLastOptionNil := func() {
require.Equal(t, 1, len(sbc0.Options))
assert.Nil(t, sbc0.Options[0])
sbc0.Options = nil
}

session := econtext.NewSafeSession(&vtgatepb.Session{Options: &querypb.ExecuteOptions{}})
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, tt.fetchLastInsertID)
require.NoError(t, vterrors.Aggregate(errs))

fetchLastInsertID := true
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, fetchLastInsertID)
require.NoError(t, vterrors.Aggregate(errs))
assert.True(t, session.Options.FetchLastInsertId)
if tt.expectSessionNil {
assert.Nil(t, session.Options)
} else {
assert.NotNil(t, session.Options)
assert.Equal(t, tt.fetchLastInsertID, session.Options.FetchLastInsertId)
}

fetchLastInsertID = false
_, errs = sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, fetchLastInsertID)
require.NoError(t, vterrors.Aggregate(errs))
assert.False(t, session.Options.FetchLastInsertId)
if tt.expectFetchLastID == nil {
checkLastOptionNil()
} else {
checkLastOption(*tt.expectFetchLastID)
}
})
}
}

func TestExecutePanic(t *testing.T) {
Expand Down

0 comments on commit aebd8e4

Please sign in to comment.