diff --git a/go/test/vschemawrapper/vschema_wrapper.go b/go/test/vschemawrapper/vschema_wrapper.go index b362a8b7408..3f9f072afc6 100644 --- a/go/test/vschemawrapper/vschema_wrapper.go +++ b/go/test/vschemawrapper/vschema_wrapper.go @@ -33,6 +33,7 @@ import ( "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" @@ -40,7 +41,10 @@ import ( var _ plancontext.VSchema = (*VSchemaWrapper)(nil) +// VSchemaWrapper is a wrapper around VSchema that implements the ContextVSchema interface. +// It is used in tests to provide a VSchema implementation. type VSchemaWrapper struct { + Vcursor *econtext.VCursorImpl V *vindexes.VSchema Keyspace *vindexes.Keyspace TabletType_ topodatapb.TabletType @@ -53,6 +57,30 @@ type VSchemaWrapper struct { Env *vtenv.Environment } +func NewVschemaWrapper( + env *vtenv.Environment, + vschema *vindexes.VSchema, + builder func(string, plancontext.VSchema, string) (*engine.Plan, error), +) (*VSchemaWrapper, error) { + ss := econtext.NewAutocommitSession(&vtgatepb.Session{}) + vcursor, err := econtext.NewVCursorImpl(ss, sqlparser.MarginComments{}, nil, nil, nil, vschema, nil, nil, nil, econtext.VCursorConfig{ + Collation: env.CollationEnv().DefaultConnectionCharset(), + DefaultTabletType: topodatapb.TabletType_PRIMARY, + SetVarEnabled: true, + }) + if err != nil { + return nil, err + } + return &VSchemaWrapper{ + Env: env, + V: vschema, + Vcursor: vcursor, + TestBuilder: builder, + TabletType_: topodatapb.TabletType_PRIMARY, + SysVarEnabled: true, + }, nil +} + func (vw *VSchemaWrapper) GetPrepareData(stmtName string) *vtgatepb.PrepareData { switch stmtName { case "prep_one_param": @@ -244,34 +272,7 @@ func (vw *VSchemaWrapper) FindView(tab sqlparser.TableName) sqlparser.SelectStat } func (vw *VSchemaWrapper) FindTableOrVindex(tab sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) { - if tab.Qualifier.IsEmpty() && tab.Name.String() == "dual" { - ksName := vw.getActualKeyspace() - var ks *vindexes.Keyspace - if ksName == "" { - ks = vw.getfirstKeyspace() - ksName = ks.Name - } else { - ks = vw.V.Keyspaces[ksName].Keyspace - } - tbl := &vindexes.Table{ - Name: sqlparser.NewIdentifierCS("dual"), - Keyspace: ks, - Type: vindexes.TypeReference, - } - return tbl, nil, ksName, topodatapb.TabletType_PRIMARY, nil, nil - } - destKeyspace, destTabletType, destTarget, err := topoproto.ParseDestination(tab.Qualifier.String(), topodatapb.TabletType_PRIMARY) - if err != nil { - return nil, nil, destKeyspace, destTabletType, destTarget, err - } - if destKeyspace == "" { - destKeyspace = vw.getActualKeyspace() - } - table, vindex, err := vw.V.FindTableOrVindex(destKeyspace, tab.Name.String(), topodatapb.TabletType_PRIMARY) - if err != nil { - return nil, nil, destKeyspace, destTabletType, destTarget, err - } - return table, vindex, destKeyspace, destTabletType, destTarget, nil + return vw.Vcursor.FindTableOrVindex(tab) } func (vw *VSchemaWrapper) getfirstKeyspace() (ks *vindexes.Keyspace) { diff --git a/go/vt/vtexplain/vtexplain_vtgate.go b/go/vt/vtexplain/vtexplain_vtgate.go index d511e2d2ea0..f9ae8be3820 100644 --- a/go/vt/vtexplain/vtexplain_vtgate.go +++ b/go/vt/vtexplain/vtexplain_vtgate.go @@ -38,6 +38,7 @@ import ( "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate" "vitess.io/vitess/go/vt/vtgate/engine" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/logstats" "vitess.io/vitess/go/vt/vtgate/vindexes" "vitess.io/vitess/go/vt/vttablet/queryservice" @@ -235,7 +236,7 @@ func (vte *VTExplain) vtgateExecute(sql string) ([]*engine.Plan, map[string]*Tab // This will ensure that the commit/rollback order is predictable. vte.sortShardSession() - _, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", vtgate.NewSafeSession(vte.vtgateSession), sql, nil) + _, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", econtext.NewSafeSession(vte.vtgateSession), sql, nil) if err != nil { for _, tc := range vte.explainTopo.TabletConns { tc.tabletQueries = nil diff --git a/go/vt/vtgate/autocommit_test.go b/go/vt/vtgate/autocommit_test.go index 1ba99c01ef2..2e65cefbabe 100644 --- a/go/vt/vtgate/autocommit_test.go +++ b/go/vt/vtgate/autocommit_test.go @@ -23,10 +23,10 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" - querypb "vitess.io/vitess/go/vt/proto/query" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" ) // This file contains tests for all the autocommit code paths @@ -382,7 +382,7 @@ func TestAutocommitTransactionStarted(t *testing.T) { // single shard query - no savepoint needed sql := "update `user` set a = 2 where id = 1" - _, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) + _, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) require.Len(t, sbc1.Queries, 1) require.Equal(t, sql, sbc1.Queries[0].Sql) @@ -394,7 +394,7 @@ func TestAutocommitTransactionStarted(t *testing.T) { // multi shard query - savepoint needed sql = "update `user` set a = 2 where id in (1, 4)" expectedSql := "update `user` set a = 2 where id in ::__vals" - _, err = executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) + _, err = executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) require.Len(t, sbc1.Queries, 2) require.Contains(t, sbc1.Queries[0].Sql, "savepoint") @@ -413,7 +413,7 @@ func TestAutocommitDirectTarget(t *testing.T) { } sql := "insert into `simple`(val) values ('val')" - _, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) + _, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) assertQueries(t, sbclookup, []*querypb.BoundQuery{{ @@ -434,7 +434,7 @@ func TestAutocommitDirectRangeTarget(t *testing.T) { } sql := "delete from sharded_user_msgs limit 1000" - _, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) + _, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) assertQueries(t, sbc1, []*querypb.BoundQuery{{ @@ -451,5 +451,5 @@ func autocommitExec(executor *Executor, sql string) (*sqltypes.Result, error) { TransactionMode: vtgatepb.TransactionMode_MULTI, } - return executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) + return executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) } diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index c56e076e2fd..e84ab7fbb21 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -30,6 +30,8 @@ import ( "github.com/spf13/pflag" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/acl" "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/mysql/capabilities" @@ -57,6 +59,7 @@ import ( "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/evalengine" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/logstats" "vitess.io/vitess/go/vt/vtgate/planbuilder" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -67,7 +70,6 @@ import ( ) var ( - errNoKeyspace = vterrors.VT09005() defaultTabletType = topodatapb.TabletType_PRIMARY // TODO: @rafael - These two counters should be deprecated in favor of the ByTable ones in v17+. They are kept for now for backwards compatibility. @@ -111,7 +113,6 @@ type Executor struct { resolver *Resolver scatterConn *ScatterConn txConn *TxConn - pv plancontext.PlannerVersion mu sync.Mutex vschema *vindexes.VSchema @@ -121,8 +122,7 @@ type Executor struct { plans *PlanCache epoch atomic.Uint32 - normalize bool - warnShardedOnly bool + normalize bool vm *VSchemaManager schemaTracker SchemaInfo @@ -135,6 +135,8 @@ type Executor struct { warmingReadsPercent int warmingReadsChannel chan bool + + vConfig econtext.VCursorConfig } var executorOnce sync.Once @@ -175,15 +177,15 @@ func NewExecutor( scatterConn: resolver.scatterConn, txConn: resolver.scatterConn.txConn, normalize: normalize, - warnShardedOnly: warnOnShardedOnly, streamSize: streamSize, schemaTracker: schemaTracker, allowScatter: !noScatter, - pv: pv, plans: plans, warmingReadsPercent: warmingReadsPercent, warmingReadsChannel: make(chan bool, warmingReadsConcurrency), } + // setting the vcursor config. + e.initVConfig(warnOnShardedOnly, pv) vschemaacl.Init() // we subscribe to update from the VSchemaManager @@ -223,7 +225,7 @@ func NewExecutor( } // Execute executes a non-streaming query. -func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) { +func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) { span, ctx := trace.NewSpan(ctx, "executor.Execute") span.Annotate("method", method) trace.AnnotateSQL(span, sqlparser.Preview(sql)) @@ -286,7 +288,7 @@ func (e *Executor) StreamExecute( ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, - safeSession *SafeSession, + safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error, @@ -300,7 +302,7 @@ func (e *Executor) StreamExecute( srr := &streaminResultReceiver{callback: callback} var err error - resultHandler := func(ctx context.Context, plan *engine.Plan, vc *vcursorImpl, bindVars map[string]*querypb.BindVariable, execStart time.Time) error { + resultHandler := func(ctx context.Context, plan *engine.Plan, vc *econtext.VCursorImpl, bindVars map[string]*querypb.BindVariable, execStart time.Time) error { var seenResults atomic.Bool var resultMu sync.Mutex result := &sqltypes.Result{} @@ -368,7 +370,7 @@ func (e *Executor) StreamExecute( logStats.TablesUsed = plan.TablesUsed logStats.TabletType = vc.TabletType().String() logStats.ExecuteTime = time.Since(execStart) - logStats.ActiveKeyspace = vc.keyspace + logStats.ActiveKeyspace = vc.GetKeyspace() e.updateQueryCounts(plan.Instructions.RouteType(), plan.Instructions.GetKeyspaceName(), plan.Instructions.GetTableName(), int64(logStats.ShardQueries)) @@ -411,12 +413,12 @@ func canReturnRows(stmtType sqlparser.StatementType) bool { } } -func saveSessionStats(safeSession *SafeSession, stmtType sqlparser.StatementType, rowsAffected, insertID uint64, rowsReturned int, err error) { +func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.StatementType, rowsAffected, insertID uint64, rowsReturned int, err error) { safeSession.RowCount = -1 if err != nil { return } - if !safeSession.foundRowsHandled { + if !safeSession.IsFoundRowsHandled() { safeSession.FoundRows = uint64(rowsReturned) } if insertID > 0 { @@ -430,11 +432,11 @@ func saveSessionStats(safeSession *SafeSession, stmtType sqlparser.StatementType } } -func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) (sqlparser.StatementType, *sqltypes.Result, error) { +func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) (sqlparser.StatementType, *sqltypes.Result, error) { var err error var qr *sqltypes.Result var stmtType sqlparser.StatementType - err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, func(ctx context.Context, plan *engine.Plan, vc *vcursorImpl, bindVars map[string]*querypb.BindVariable, time time.Time) error { + err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, func(ctx context.Context, plan *engine.Plan, vc *econtext.VCursorImpl, bindVars map[string]*querypb.BindVariable, time time.Time) error { stmtType = plan.Type qr, err = e.executePlan(ctx, safeSession, plan, vc, bindVars, logStats, time) return err @@ -448,7 +450,7 @@ func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConn } // addNeededBindVars adds bind vars that are needed by the plan -func (e *Executor) addNeededBindVars(vcursor *vcursorImpl, bindVarNeeds *sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *SafeSession) error { +func (e *Executor) addNeededBindVars(vcursor *econtext.VCursorImpl, bindVarNeeds *sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *econtext.SafeSession) error { for _, funcName := range bindVarNeeds.NeedFunctionResult { switch funcName { case sqlparser.DBVarName: @@ -541,7 +543,7 @@ func (e *Executor) addNeededBindVars(vcursor *vcursorImpl, bindVarNeeds *sqlpars } evalExpr, err := evalengine.Translate(expr, &evalengine.Config{ - Collation: vcursor.collation, + Collation: vcursor.ConnCollation(), Environment: e.env, SQLMode: evalengine.ParseSQLMode(vcursor.SQLMode()), }) @@ -552,7 +554,7 @@ func (e *Executor) addNeededBindVars(vcursor *vcursorImpl, bindVarNeeds *sqlpars if err != nil { return err } - bindVars[key] = sqltypes.ValueBindVariable(evaluated.Value(vcursor.collation)) + bindVars[key] = sqltypes.ValueBindVariable(evaluated.Value(vcursor.ConnCollation())) } } } @@ -572,21 +574,21 @@ func (e *Executor) addNeededBindVars(vcursor *vcursorImpl, bindVarNeeds *sqlpars return nil } -func ifOptionsExist(session *SafeSession, f func(*querypb.ExecuteOptions)) { +func ifOptionsExist(session *econtext.SafeSession, f func(*querypb.ExecuteOptions)) { options := session.GetOptions() if options != nil { f(options) } } -func ifReadAfterWriteExist(session *SafeSession, f func(*vtgatepb.ReadAfterWrite)) { +func ifReadAfterWriteExist(session *econtext.SafeSession, f func(*vtgatepb.ReadAfterWrite)) { raw := session.ReadAfterWrite if raw != nil { f(raw) } } -func (e *Executor) handleBegin(ctx context.Context, safeSession *SafeSession, logStats *logstats.LogStats, stmt sqlparser.Statement) (*sqltypes.Result, error) { +func (e *Executor) handleBegin(ctx context.Context, safeSession *econtext.SafeSession, logStats *logstats.LogStats, stmt sqlparser.Statement) (*sqltypes.Result, error) { execStart := time.Now() logStats.PlanTime = execStart.Sub(logStats.StartTime) @@ -599,7 +601,7 @@ func (e *Executor) handleBegin(ctx context.Context, safeSession *SafeSession, lo return &sqltypes.Result{}, err } -func (e *Executor) handleCommit(ctx context.Context, safeSession *SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) { +func (e *Executor) handleCommit(ctx context.Context, safeSession *econtext.SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) { execStart := time.Now() logStats.PlanTime = execStart.Sub(logStats.StartTime) logStats.ShardQueries = uint64(len(safeSession.ShardSessions)) @@ -611,11 +613,11 @@ func (e *Executor) handleCommit(ctx context.Context, safeSession *SafeSession, l } // Commit commits the existing transactions -func (e *Executor) Commit(ctx context.Context, safeSession *SafeSession) error { +func (e *Executor) Commit(ctx context.Context, safeSession *econtext.SafeSession) error { return e.txConn.Commit(ctx, safeSession) } -func (e *Executor) handleRollback(ctx context.Context, safeSession *SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) { +func (e *Executor) handleRollback(ctx context.Context, safeSession *econtext.SafeSession, logStats *logstats.LogStats) (*sqltypes.Result, error) { execStart := time.Now() logStats.PlanTime = execStart.Sub(logStats.StartTime) logStats.ShardQueries = uint64(len(safeSession.ShardSessions)) @@ -625,7 +627,7 @@ func (e *Executor) handleRollback(ctx context.Context, safeSession *SafeSession, return &sqltypes.Result{}, err } -func (e *Executor) handleSavepoint(ctx context.Context, safeSession *SafeSession, sql string, planType string, logStats *logstats.LogStats, nonTxResponse func(query string) (*sqltypes.Result, error), ignoreMaxMemoryRows bool) (*sqltypes.Result, error) { +func (e *Executor) handleSavepoint(ctx context.Context, safeSession *econtext.SafeSession, sql string, planType string, logStats *logstats.LogStats, nonTxResponse func(query string) (*sqltypes.Result, error), ignoreMaxMemoryRows bool) (*sqltypes.Result, error) { execStart := time.Now() logStats.PlanTime = execStart.Sub(logStats.StartTime) logStats.ShardQueries = uint64(len(safeSession.ShardSessions)) @@ -637,7 +639,7 @@ func (e *Executor) handleSavepoint(ctx context.Context, safeSession *SafeSession // If no transaction exists on any of the shard sessions, // then savepoint does not need to be executed, it will be only stored in the session // and later will be executed when a transaction is started. - if !safeSession.isTxOpen() { + if !safeSession.IsTxOpen() { if safeSession.InTransaction() { // Storing, as this needs to be executed just after starting transaction on the shard. safeSession.StoreSavepoint(sql) @@ -645,7 +647,7 @@ func (e *Executor) handleSavepoint(ctx context.Context, safeSession *SafeSession } return nonTxResponse(sql) } - orig := safeSession.commitOrder + orig := safeSession.GetCommitOrder() qr, err := e.executeSPInAllSessions(ctx, safeSession, sql, ignoreMaxMemoryRows) safeSession.SetCommitOrder(orig) if err != nil { @@ -657,7 +659,7 @@ func (e *Executor) handleSavepoint(ctx context.Context, safeSession *SafeSession // executeSPInAllSessions function executes the savepoint query in all open shard sessions (pre, normal and post) // which has non-zero transaction id (i.e. an open transaction on the shard connection). -func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *SafeSession, sql string, ignoreMaxMemoryRows bool) (*sqltypes.Result, error) { +func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *econtext.SafeSession, sql string, ignoreMaxMemoryRows bool) (*sqltypes.Result, error) { var qr *sqltypes.Result var errs []error for _, co := range []vtgatepb.CommitOrder{vtgatepb.CommitOrder_PRE, vtgatepb.CommitOrder_NORMAL, vtgatepb.CommitOrder_POST} { @@ -665,7 +667,7 @@ func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *Safe var rss []*srvtopo.ResolvedShard var queries []*querypb.BoundQuery - for _, shardSession := range safeSession.getSessions() { + for _, shardSession := range safeSession.GetSessions() { // This will avoid executing savepoint on reserved connections // which has no open transaction. if shardSession.TransactionId == 0 { @@ -718,11 +720,11 @@ func (e *Executor) handleKill(ctx context.Context, mysqlCtx vtgateservice.MySQLC // CloseSession releases the current connection, which rollbacks open transactions and closes reserved connections. // It is called then the MySQL servers closes the connection to its client. -func (e *Executor) CloseSession(ctx context.Context, safeSession *SafeSession) error { +func (e *Executor) CloseSession(ctx context.Context, safeSession *econtext.SafeSession) error { return e.txConn.ReleaseAll(ctx, safeSession) } -func (e *Executor) setVitessMetadata(ctx context.Context, name, value string) error { +func (e *Executor) SetVitessMetadata(ctx context.Context, name, value string) error { // TODO(kalfonso): move to its own acl check and consolidate into an acl component that can handle multiple operations (vschema, metadata) user := callerid.ImmediateCallerIDFromContext(ctx) allowed := vschemaacl.Authorized(user) @@ -741,7 +743,7 @@ func (e *Executor) setVitessMetadata(ctx context.Context, name, value string) er return ts.UpsertMetadata(ctx, name, value) } -func (e *Executor) showVitessMetadata(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { +func (e *Executor) ShowVitessMetadata(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { ts, err := e.serv.GetTopoServer() if err != nil { return nil, err @@ -774,7 +776,7 @@ func (e *Executor) showVitessMetadata(ctx context.Context, filter *sqlparser.Sho type tabletFilter func(tablet *topodatapb.Tablet, servingState string, primaryTermStartTime int64) bool -func (e *Executor) showShards(ctx context.Context, filter *sqlparser.ShowFilter, destTabletType topodatapb.TabletType) (*sqltypes.Result, error) { +func (e *Executor) ShowShards(ctx context.Context, filter *sqlparser.ShowFilter, destTabletType topodatapb.TabletType) (*sqltypes.Result, error) { showVitessShardsFilters := func(filter *sqlparser.ShowFilter) ([]func(string) bool, []func(string, *topodatapb.ShardReference) bool) { keyspaceFilters := []func(string) bool{} shardFilters := []func(string, *topodatapb.ShardReference) bool{} @@ -858,7 +860,7 @@ func (e *Executor) showShards(ctx context.Context, filter *sqlparser.ShowFilter, }, nil } -func (e *Executor) showTablets(filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { +func (e *Executor) ShowTablets(filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { getTabletFilters := func(filter *sqlparser.ShowFilter) []tabletFilter { var filters []tabletFilter @@ -931,7 +933,7 @@ func (e *Executor) showTablets(filter *sqlparser.ShowFilter) (*sqltypes.Result, }, nil } -func (e *Executor) showVitessReplicationStatus(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { +func (e *Executor) ShowVitessReplicationStatus(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { ctx, cancel := context.WithTimeout(ctx, healthCheckTimeout) defer cancel() rows := [][]sqltypes.Value{} @@ -1078,26 +1080,14 @@ func (e *Executor) SaveVSchema(vschema *vindexes.VSchema, stats *VSchemaStats) { // ParseDestinationTarget parses destination target string and sets default keyspace if possible. func (e *Executor) ParseDestinationTarget(targetString string) (string, topodatapb.TabletType, key.Destination, error) { - destKeyspace, destTabletType, dest, err := topoproto.ParseDestination(targetString, defaultTabletType) - // Set default keyspace - if destKeyspace == "" && len(e.VSchema().Keyspaces) == 1 { - for k := range e.VSchema().Keyspaces { - destKeyspace = k - } - } - return destKeyspace, destTabletType, dest, err -} - -type iQueryOption interface { - cachePlan() bool - getSelectLimit() int + return econtext.ParseDestinationTarget(targetString, defaultTabletType, e.VSchema()) } // getPlan computes the plan for the given query. If one is in // the cache, it reuses it. func (e *Executor) getPlan( ctx context.Context, - vcursor *vcursorImpl, + vcursor *econtext.VCursorImpl, query string, stmt sqlparser.Statement, comments sqlparser.MarginComments, @@ -1135,10 +1125,10 @@ func (e *Executor) getPlan( reservedVars, bindVars, parameterize, - vcursor.keyspace, - vcursor.safeSession.getSelectLimit(), + vcursor.GetKeyspace(), + vcursor.SafeSession.GetSelectLimit(), setVarComment, - vcursor.safeSession.SystemVariables, + vcursor.GetSystemVariablesCopy(), vcursor.GetForeignKeyChecksState(), vcursor, ) @@ -1157,9 +1147,9 @@ func (e *Executor) getPlan( return e.cacheAndBuildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds, logStats) } -func (e *Executor) hashPlan(ctx context.Context, vcursor *vcursorImpl, query string) PlanCacheKey { +func (e *Executor) hashPlan(ctx context.Context, vcursor *econtext.VCursorImpl, query string) PlanCacheKey { hasher := vthash.New256() - vcursor.keyForPlan(ctx, query, hasher) + vcursor.KeyForPlan(ctx, query, hasher) var planKey PlanCacheKey hasher.Sum(planKey[:0]) @@ -1168,7 +1158,7 @@ func (e *Executor) hashPlan(ctx context.Context, vcursor *vcursorImpl, query str func (e *Executor) buildStatement( ctx context.Context, - vcursor *vcursorImpl, + vcursor *econtext.VCursorImpl, query string, stmt sqlparser.Statement, reservedVars *sqlparser.ReservedVars, @@ -1183,8 +1173,7 @@ func (e *Executor) buildStatement( return nil, err } - plan.Warnings = vcursor.warnings - vcursor.warnings = nil + plan.Warnings = vcursor.GetAndEmptyWarnings() err = e.checkThatPlanIsValid(stmt, plan) return plan, err @@ -1192,14 +1181,14 @@ func (e *Executor) buildStatement( func (e *Executor) cacheAndBuildStatement( ctx context.Context, - vcursor *vcursorImpl, + vcursor *econtext.VCursorImpl, query string, stmt sqlparser.Statement, reservedVars *sqlparser.ReservedVars, bindVarNeeds *sqlparser.BindVarNeeds, logStats *logstats.LogStats, ) (*engine.Plan, error) { - planCachable := sqlparser.CachePlan(stmt) && vcursor.safeSession.cachePlan() + planCachable := sqlparser.CachePlan(stmt) && vcursor.CachePlan() if planCachable { planKey := e.hashPlan(ctx, vcursor, query) @@ -1217,7 +1206,7 @@ func (e *Executor) canNormalizeStatement(stmt sqlparser.Statement, setVarComment return sqlparser.CanNormalize(stmt) || setVarComment != "" } -func prepareSetVarComment(vcursor *vcursorImpl, stmt sqlparser.Statement) (string, error) { +func prepareSetVarComment(vcursor *econtext.VCursorImpl, stmt sqlparser.Statement) (string, error) { if vcursor == nil || vcursor.Session().InReservedConn() { return "", nil } @@ -1358,7 +1347,7 @@ func isValidPayloadSize(query string) bool { } // Prepare executes a prepare statements. -func (e *Executor) Prepare(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (fld []*querypb.Field, err error) { +func (e *Executor) Prepare(ctx context.Context, method string, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (fld []*querypb.Field, err error) { logStats := logstats.NewLogStats(ctx, method, sql, safeSession.GetSessionUUID(), bindVars) fld, err = e.prepare(ctx, safeSession, sql, bindVars, logStats) logStats.Error = err @@ -1377,7 +1366,7 @@ func (e *Executor) Prepare(ctx context.Context, method string, safeSession *Safe return fld, err } -func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) { +func (e *Executor) prepare(ctx context.Context, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) { // Start an implicit transaction if necessary. if !safeSession.Autocommit && !safeSession.InTransaction() { if err := e.txConn.Begin(ctx, safeSession, nil); err != nil { @@ -1413,9 +1402,41 @@ func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql st return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unrecognized prepare statement: %s", sql) } -func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) { +func (e *Executor) initVConfig(warnOnShardedOnly bool, pv plancontext.PlannerVersion) { + connCollation := collations.Unknown + if gw, isTabletGw := e.resolver.resolver.GetGateway().(*TabletGateway); isTabletGw { + connCollation = gw.DefaultConnCollation() + } + if connCollation == collations.Unknown { + connCollation = e.env.CollationEnv().DefaultConnectionCharset() + } + + e.vConfig = econtext.VCursorConfig{ + Collation: connCollation, + DefaultTabletType: defaultTabletType, + PlannerVersion: pv, + + QueryTimeout: queryTimeout, + MaxMemoryRows: maxMemoryRows, + + SetVarEnabled: sysVarSetEnabled, + EnableViews: enableViews, + ForeignKeyMode: fkMode(foreignKeyMode), + EnableShardRouting: enableShardRouting, + WarnShardedOnly: warnOnShardedOnly, + + DBDDLPlugin: dbDDLPlugin, + + WarmingReadsPercent: e.warmingReadsPercent, + WarmingReadsTimeout: warmingReadsQueryTimeout, + WarmingReadsChannel: e.warmingReadsChannel, + } +} + +func (e *Executor) handlePrepare(ctx context.Context, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) { query, comments := sqlparser.SplitMarginComments(sql) - vcursor, _ := newVCursorImpl(safeSession, comments, e, logStats, e.vm, e.VSchema(), e.resolver.resolver, e.serv, e.warnShardedOnly, e.pv) + + vcursor, _ := econtext.NewVCursorImpl(safeSession, comments, e, logStats, e.vm, e.VSchema(), e.resolver.resolver, e.serv, nullResultsObserver{}, e.vConfig) stmt, reservedVars, err := parseAndValidateQuery(query, e.env.Parser()) if err != nil { @@ -1464,17 +1485,17 @@ func parseAndValidateQuery(query string, parser *sqlparser.Parser) (sqlparser.St } // ExecuteMultiShard implements the IExecutor interface -func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver resultsObserver) (qr *sqltypes.Result, errs []error) { +func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *econtext.SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver econtext.ResultsObserver) (qr *sqltypes.Result, errs []error) { return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows, resultsObserver) } // StreamExecuteMulti implements the IExecutor interface -func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, resultsObserver resultsObserver) []error { +func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *econtext.SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, resultsObserver econtext.ResultsObserver) []error { return e.scatterConn.StreamExecuteMulti(ctx, primitive, query, rss, vars, session, autocommit, callback, resultsObserver) } // ExecuteLock implements the IExecutor interface -func (e *Executor) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) { +func (e *Executor) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *econtext.SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) { return e.scatterConn.ExecuteLock(ctx, rs, query, session, lockFuncType) } @@ -1585,25 +1606,25 @@ func getTabletThrottlerStatus(tabletHostPort string) (string, error) { } // ReleaseLock implements the IExecutor interface -func (e *Executor) ReleaseLock(ctx context.Context, session *SafeSession) error { +func (e *Executor) ReleaseLock(ctx context.Context, session *econtext.SafeSession) error { return e.txConn.ReleaseLock(ctx, session) } -// planPrepareStmt implements the IExecutor interface -func (e *Executor) planPrepareStmt(ctx context.Context, vcursor *vcursorImpl, query string) (*engine.Plan, sqlparser.Statement, error) { +// PlanPrepareStmt implements the IExecutor interface +func (e *Executor) PlanPrepareStmt(ctx context.Context, vcursor *econtext.VCursorImpl, query string) (*engine.Plan, sqlparser.Statement, error) { stmt, reservedVars, err := parseAndValidateQuery(query, e.env.Parser()) if err != nil { return nil, nil, err } // creating this log stats to not interfere with the original log stats. - lStats := logstats.NewLogStats(ctx, "prepare", query, vcursor.safeSession.SessionUUID, nil) + lStats := logstats.NewLogStats(ctx, "prepare", query, vcursor.Session().GetSessionUUID(), nil) plan, err := e.getPlan( ctx, vcursor, query, sqlparser.Clone(stmt), - vcursor.marginComments, + vcursor.GetMarginComments(), map[string]*querypb.BindVariable{}, reservedVars, /* normalize */ false, @@ -1625,7 +1646,7 @@ func (e *Executor) Close() { e.plans.Close() } -func (e *Executor) environment() *vtenv.Environment { +func (e *Executor) Environment() *vtenv.Environment { return e.env } @@ -1637,6 +1658,10 @@ func (e *Executor) UnresolvedTransactions(ctx context.Context, targets []*queryp return e.txConn.UnresolvedTransactions(ctx, targets) } +func (e *Executor) AddWarningCount(name string, count int64) { + warnings.Add(name, count) +} + type ( errorTransformer interface { TransformError(err error) error @@ -1647,3 +1672,16 @@ type ( func (nullErrorTransformer) TransformError(err error) error { return err } + +func fkMode(foreignkey string) vschemapb.Keyspace_ForeignKeyMode { + switch foreignkey { + case "disallow": + return vschemapb.Keyspace_disallow + case "managed": + return vschemapb.Keyspace_managed + case "unmanaged": + return vschemapb.Keyspace_unmanaged + + } + return vschemapb.Keyspace_unspecified +} diff --git a/go/vt/vtgate/executor_ddl_test.go b/go/vt/vtgate/executor_ddl_test.go index b0e2cf4b873..bf117856e08 100644 --- a/go/vt/vtgate/executor_ddl_test.go +++ b/go/vt/vtgate/executor_ddl_test.go @@ -21,6 +21,7 @@ import ( "testing" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "github.com/stretchr/testify/require" ) @@ -56,7 +57,7 @@ func TestDDLFlags(t *testing.T) { for _, testcase := range testcases { t.Run(fmt.Sprintf("%s-%v-%v", testcase.sql, testcase.enableDirectDDL, testcase.enableOnlineDDL), func(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) enableDirectDDL.Set(testcase.enableDirectDDL) enableOnlineDDL.Set(testcase.enableOnlineDDL) _, err := executor.Execute(ctx, nil, "TestDDLFlags", session, testcase.sql, nil) diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index 0ebf5f80824..792e197f48d 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -25,6 +25,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "vitess.io/vitess/go/mysql/config" "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" @@ -135,7 +137,6 @@ func TestUpdateEqual(t *testing.T) { func TestUpdateFromSubQuery(t *testing.T) { executor, sbc1, sbc2, _, ctx := createExecutorEnv(t) - executor.pv = querypb.ExecuteOptions_Gen4 logChan := executor.queryLogger.Subscribe("Test") defer executor.queryLogger.Unsubscribe(logChan) @@ -234,7 +235,7 @@ func TestUpdateInTransactionLookupDefaultReadLock(t *testing.T) { )} executor, sbc1, sbc2, sbcLookup, ctx := createCustomExecutorSetValues(t, executorVSchema, res) - safeSession := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + safeSession := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) _, err := executorExecSession(ctx, executor, "update t2_lookup set lu_col = 5 where nv_lu_col = 2", @@ -296,7 +297,7 @@ func TestUpdateInTransactionLookupExclusiveReadLock(t *testing.T) { )} executor, sbc1, sbc2, sbcLookup, ctx := createCustomExecutorSetValues(t, executorVSchema, res) - safeSession := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + safeSession := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) _, err := executorExecSession(ctx, executor, "update t2_lookup set lu_col = 5 where erl_lu_col = 2", @@ -358,7 +359,7 @@ func TestUpdateInTransactionLookupSharedReadLock(t *testing.T) { )} executor, sbc1, sbc2, sbcLookup, ctx := createCustomExecutorSetValues(t, executorVSchema, res) - safeSession := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + safeSession := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) _, err := executorExecSession(ctx, executor, "update t2_lookup set lu_col = 5 where srl_lu_col = 2", @@ -420,7 +421,7 @@ func TestUpdateInTransactionLookupNoReadLock(t *testing.T) { )} executor, sbc1, sbc2, sbcLookup, ctx := createCustomExecutorSetValues(t, executorVSchema, res) - safeSession := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + safeSession := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) _, err := executorExecSession(ctx, executor, "update t2_lookup set lu_col = 5 where nrl_lu_col = 2", @@ -2066,7 +2067,7 @@ func TestInsertPartialFail1(t *testing.T) { context.Background(), nil, "TestExecute", - NewSafeSession(&vtgatepb.Session{InTransaction: true}), + econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}), "insert into user(id, v, name) values (1, 2, 'myname')", nil, ) @@ -2082,7 +2083,7 @@ func TestInsertPartialFail2(t *testing.T) { // Make the second DML fail, it should result in a rollback. sbc1.MustFailExecute[sqlparser.StmtInsert] = 1 - safeSession := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + safeSession := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) _, err := executor.Execute( context.Background(), nil, @@ -2656,7 +2657,7 @@ func TestReservedConnDML(t *testing.T) { logChan := executor.queryLogger.Subscribe("TestReservedConnDML") defer executor.queryLogger.Unsubscribe(logChan) - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true}) _, err := executor.Execute(ctx, nil, "TestReservedConnDML", session, "use "+KsTestUnsharded, nil) require.NoError(t, err) @@ -2708,7 +2709,7 @@ func TestStreamingDML(t *testing.T) { logChan := executor.queryLogger.Subscribe(method) defer executor.queryLogger.Unsubscribe(logChan) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) tcases := []struct { query string @@ -2792,7 +2793,7 @@ func TestPartialVindexInsertQueryFailure(t *testing.T) { logChan := executor.queryLogger.Subscribe("Test") defer executor.queryLogger.Unsubscribe(logChan) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) require.True(t, session.GetAutocommit()) require.False(t, session.InTransaction()) @@ -2845,7 +2846,7 @@ func TestPartialVindexInsertQueryFailureAutoCommit(t *testing.T) { logChan := executor.queryLogger.Subscribe("Test") defer executor.queryLogger.Unsubscribe(logChan) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) require.True(t, session.GetAutocommit()) require.False(t, session.InTransaction()) @@ -2886,7 +2887,7 @@ func TestPartialVindexInsertQueryFailureAutoCommit(t *testing.T) { func TestMultiInternalSavepoint(t *testing.T) { executor, sbc1, sbc2, _, ctx := createExecutorEnv(t) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) _, err := executorExecSession(ctx, executor, "begin", nil, session.Session) require.NoError(t, err) @@ -2935,7 +2936,7 @@ func TestInsertSelectFromDual(t *testing.T) { logChan := executor.queryLogger.Subscribe("TestInsertSelect") defer executor.queryLogger.Unsubscribe(logChan) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) query := "insert into user(id, v, name) select 1, 2, 'myname' from dual" wantQueries := []*querypb.BoundQuery{{ @@ -2990,7 +2991,7 @@ func TestInsertSelectFromTable(t *testing.T) { logChan := executor.queryLogger.Subscribe("TestInsertSelect") defer executor.queryLogger.Unsubscribe(logChan) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) query := "insert into user(id, name) select c1, c2 from music" wantQueries := []*querypb.BoundQuery{{ @@ -3146,7 +3147,7 @@ func TestSessionRowsAffected(t *testing.T) { method := t.Name() executor, _, sbc4060, _, ctx := createExecutorEnv(t) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) // start the transaction _, err := executor.Execute(ctx, nil, method, session, "begin", nil) diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go index 332139c4a78..2ee3425209f 100644 --- a/go/vt/vtgate/executor_framework_test.go +++ b/go/vt/vtgate/executor_framework_test.go @@ -28,6 +28,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/constants/sidecar" "vitess.io/vitess/go/sqltypes" @@ -307,7 +309,7 @@ func executorExecSession(ctx context.Context, executor *Executor, sql string, bv ctx, nil, "TestExecute", - NewSafeSession(session), + econtext.NewSafeSession(session), sql, bv) } @@ -320,7 +322,7 @@ func executorPrepare(ctx context.Context, executor *Executor, session *vtgatepb. return executor.Prepare( ctx, "TestExecute", - NewSafeSession(session), + econtext.NewSafeSession(session), sql, bv) } @@ -331,7 +333,7 @@ func executorStream(ctx context.Context, executor *Executor, sql string) (qr *sq ctx, nil, "TestExecuteStream", - NewSafeSession(nil), + econtext.NewSafeSession(nil), sql, nil, func(qr *sqltypes.Result) error { diff --git a/go/vt/vtgate/executor_scatter_stats_test.go b/go/vt/vtgate/executor_scatter_stats_test.go index 84dd2744e8b..b665f850a23 100644 --- a/go/vt/vtgate/executor_scatter_stats_test.go +++ b/go/vt/vtgate/executor_scatter_stats_test.go @@ -24,12 +24,13 @@ import ( "github.com/stretchr/testify/require" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" ) func TestScatterStatsWithNoScatterQuery(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) _, err := executor.Execute(ctx, nil, "TestExecutorResultsExceeded", session, "select * from main1", nil) require.NoError(t, err) @@ -41,7 +42,7 @@ func TestScatterStatsWithNoScatterQuery(t *testing.T) { func TestScatterStatsWithSingleScatterQuery(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) _, err := executor.Execute(ctx, nil, "TestExecutorResultsExceeded", session, "select * from user", nil) require.NoError(t, err) @@ -53,7 +54,7 @@ func TestScatterStatsWithSingleScatterQuery(t *testing.T) { func TestScatterStatsHttpWriting(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) _, err := executor.Execute(ctx, nil, "TestExecutorResultsExceeded", session, "select * from user", nil) require.NoError(t, err) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 8ba89d25daf..86aafaefba4 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -30,6 +30,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + _flag "vitess.io/vitess/go/internal/flag" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -59,7 +61,7 @@ func TestSelectNext(t *testing.T) { }} // Autocommit - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) _, err := executor.Execute(context.Background(), nil, "TestSelectNext", session, query, bv) require.NoError(t, err) @@ -69,7 +71,7 @@ func TestSelectNext(t *testing.T) { sbclookup.Queries = nil // Txn - session = NewAutocommitSession(&vtgatepb.Session{}) + session = econtext.NewAutocommitSession(&vtgatepb.Session{}) session.Session.InTransaction = true _, err = executor.Execute(context.Background(), nil, "TestSelectNext", session, query, bv) require.NoError(t, err) @@ -80,7 +82,7 @@ func TestSelectNext(t *testing.T) { sbclookup.Queries = nil // Reserve - session = NewAutocommitSession(&vtgatepb.Session{}) + session = econtext.NewAutocommitSession(&vtgatepb.Session{}) session.Session.InReservedConn = true _, err = executor.Execute(context.Background(), nil, "TestSelectNext", session, query, bv) require.NoError(t, err) @@ -91,7 +93,7 @@ func TestSelectNext(t *testing.T) { sbclookup.Queries = nil // Reserve and Txn - session = NewAutocommitSession(&vtgatepb.Session{}) + session = econtext.NewAutocommitSession(&vtgatepb.Session{}) session.Session.InReservedConn = true session.Session.InTransaction = true _, err = executor.Execute(context.Background(), nil, "TestSelectNext", session, query, bv) @@ -107,7 +109,7 @@ func TestSelectDBA(t *testing.T) { query := "select * from INFORMATION_SCHEMA.foo" _, err := executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -117,7 +119,7 @@ func TestSelectDBA(t *testing.T) { sbc1.Queries = nil query = "SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES ist WHERE ist.table_schema = 'performance_schema' AND ist.table_name = 'foo'" _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -133,7 +135,7 @@ func TestSelectDBA(t *testing.T) { sbc1.Queries = nil query = "select 1 from information_schema.table_constraints where constraint_schema = 'vt_ks' and table_name = 'user'" _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -149,7 +151,7 @@ func TestSelectDBA(t *testing.T) { sbc1.Queries = nil query = "select 1 from information_schema.table_constraints where constraint_schema = 'vt_ks'" _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -167,7 +169,7 @@ func TestSystemVariablesMySQLBelow80(t *testing.T) { executor.normalize = true setVarEnabled = true - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: "TestExecutor"}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: "TestExecutor"}) sbc1.SetResults([]*sqltypes.Result{{ Fields: []*querypb.Field{ @@ -199,12 +201,8 @@ func TestSystemVariablesMySQLBelow80(t *testing.T) { func TestSystemVariablesWithSetVarDisabled(t *testing.T) { executor, sbc1, _, _, _ := createCustomExecutor(t, "{}", "8.0.0") executor.normalize = true - - setVarEnabled = false - defer func() { - setVarEnabled = true - }() - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: "TestExecutor"}) + executor.vConfig.SetVarEnabled = false + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: "TestExecutor"}) sbc1.SetResults([]*sqltypes.Result{{ Fields: []*querypb.Field{ @@ -237,7 +235,7 @@ func TestSetSystemVariablesTx(t *testing.T) { executor, sbc1, _, _, _ := createCustomExecutor(t, "{}", "8.0.1") executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: "TestExecutor"}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: "TestExecutor"}) _, err := executor.Execute(context.Background(), nil, "TestBegin", session, "begin", map[string]*querypb.BindVariable{}) require.NoError(t, err) @@ -283,7 +281,7 @@ func TestSetSystemVariables(t *testing.T) { executor, _, _, lookup, _ := createExecutorEnv(t) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: KsTestUnsharded, SystemVariables: map[string]string{}}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: KsTestUnsharded, SystemVariables: map[string]string{}}) // Set @@sql_mode and execute a select statement. We should have SET_VAR in the select statement @@ -394,7 +392,7 @@ func TestSetSystemVariablesWithReservedConnection(t *testing.T) { executor, sbc1, _, _, _ := createExecutorEnv(t) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, SystemVariables: map[string]string{}}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, SystemVariables: map[string]string{}}) sbc1.SetResults([]*sqltypes.Result{{ Fields: []*querypb.Field{ @@ -437,7 +435,7 @@ func TestSelectVindexFunc(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) query := "select * from hash_index where id = 1" - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) _, err := executor.Execute(context.Background(), nil, "TestSelectVindexFunc", session, query, nil) require.ErrorContains(t, err, "VT09005: no database selected") @@ -450,7 +448,7 @@ func TestCreateTableValidTimestamp(t *testing.T) { executor, sbc1, _, _, _ := createExecutorEnv(t) executor.normalize = true - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor", SystemVariables: map[string]string{"sql_mode": "ALLOW_INVALID_DATES"}}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor", SystemVariables: map[string]string{"sql_mode": "ALLOW_INVALID_DATES"}}) query := "create table aa(t timestamp default 0)" _, err := executor.Execute(context.Background(), nil, "TestSelect", session, query, map[string]*querypb.BindVariable{}) @@ -468,11 +466,10 @@ func TestCreateTableValidTimestamp(t *testing.T) { func TestGen4SelectDBA(t *testing.T) { executor, sbc1, _, _, _ := createExecutorEnv(t) executor.normalize = true - executor.pv = querypb.ExecuteOptions_Gen4 query := "select * from INFORMATION_SCHEMA.TABLE_CONSTRAINTS" _, err := executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -483,7 +480,7 @@ func TestGen4SelectDBA(t *testing.T) { sbc1.Queries = nil query = "SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES ist WHERE ist.table_schema = 'performance_schema' AND ist.table_name = 'foo'" _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -501,7 +498,7 @@ func TestGen4SelectDBA(t *testing.T) { sbc1.Queries = nil query = "select 1 from information_schema.table_constraints where constraint_schema = 'vt_ks' and table_name = 'user'" _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -519,7 +516,7 @@ func TestGen4SelectDBA(t *testing.T) { sbc1.Queries = nil query = "select 1 from information_schema.table_constraints where constraint_schema = 'vt_ks'" - _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}) + _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), query, map[string]*querypb.BindVariable{}) require.NoError(t, err) wantQueries = []*querypb.BoundQuery{{ Sql: "select :vtg1 /* INT64 */ from information_schema.table_constraints where constraint_schema = :__vtschemaname /* VARCHAR */", @@ -534,7 +531,7 @@ func TestGen4SelectDBA(t *testing.T) { sbc1.Queries = nil query = "select t.table_schema,t.table_name,c.column_name,c.column_type from tables t join columns c on c.table_schema = t.table_schema and c.table_name = t.table_name where t.table_schema = 'TestExecutor' and c.table_schema = 'TestExecutor' order by t.table_schema,t.table_name,c.column_name" _, err = executor.Execute(context.Background(), nil, "TestSelectDBA", - NewSafeSession(&vtgatepb.Session{TargetString: "information_schema"}), + econtext.NewSafeSession(&vtgatepb.Session{TargetString: "information_schema"}), query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) @@ -651,7 +648,7 @@ func TestStreamBuffering(t *testing.T) { context.Background(), nil, "TestStreamBuffering", - NewSafeSession(session), + econtext.NewSafeSession(session), "select id from music_user_map where id = 1", nil, func(qr *sqltypes.Result) error { @@ -723,7 +720,7 @@ func TestStreamLimitOffset(t *testing.T) { context.Background(), nil, "TestStreamLimitOffset", - NewSafeSession(session), + econtext.NewSafeSession(session), "select id, textcol from user order by id limit 2 offset 2", nil, func(qr *sqltypes.Result) error { @@ -1083,7 +1080,7 @@ func TestSelectDatabase(t *testing.T) { newSession := &vtgatepb.Session{ TargetString: "@primary", } - session := NewSafeSession(newSession) + session := econtext.NewSafeSession(newSession) session.TargetString = "TestExecutor@primary" result, err := executor.Execute( context.Background(), @@ -1283,7 +1280,6 @@ func TestSelectEqual(t *testing.T) { func TestSelectINFromOR(t *testing.T) { executor, sbc1, _, _, ctx := createExecutorEnv(t) - executor.pv = querypb.ExecuteOptions_Gen4 session := &vtgatepb.Session{ TargetString: "@primary", @@ -2951,7 +2947,7 @@ func TestSubQueryAndQueryWithLimit(t *testing.T) { sbc1.SetResults(result1) sbc2.SetResults(result2) - exec(executor, NewSafeSession(&vtgatepb.Session{ + exec(executor, econtext.NewSafeSession(&vtgatepb.Session{ TargetString: "@primary", }), "select id1, id2 from t1 where id1 >= ( select id1 from t1 order by id1 asc limit 1) limit 100") require.Equal(t, 2, len(sbc1.Queries)) @@ -3000,7 +2996,7 @@ func TestSelectUsingMultiEqualOnLookupColumn(t *testing.T) { }}, }}) - result, err := exec(executor, NewSafeSession(&vtgatepb.Session{ + result, err := exec(executor, econtext.NewSafeSession(&vtgatepb.Session{ TargetString: KsTestSharded, }), "select nv_lu_col, other from t2_lookup WHERE (nv_lu_col = 1 AND other = 'bar') OR (nv_lu_col = 2 AND other = 'baz') OR (nv_lu_col = 3 AND other = 'qux') OR (nv_lu_col = 4 AND other = 'brz') OR (nv_lu_col = 5 AND other = 'brz')") @@ -3197,7 +3193,7 @@ func TestSelectWithUnionAll(t *testing.T) { func TestSelectLock(t *testing.T) { executor, sbc1, _, _, _ := createExecutorEnv(t) - session := NewSafeSession(nil) + session := econtext.NewSafeSession(nil) session.Session.InTransaction = true session.ShardSessions = []*vtgatepb.Session_ShardSession{{ Target: &querypb.Target{ @@ -3265,7 +3261,7 @@ func TestLockReserve(t *testing.T) { "select release_lock('lock name') from dual", } - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) for _, sql := range tcases { t.Run(sql, func(t *testing.T) { @@ -3283,7 +3279,7 @@ func TestLockReserve(t *testing.T) { func TestSelectFromInformationSchema(t *testing.T) { executor, sbc1, _, _, _ := createExecutorEnv(t) - session := NewSafeSession(nil) + session := econtext.NewSafeSession(nil) // check failure when trying to query two keyspaces _, err := exec(executor, session, "SELECT B.TABLE_NAME FROM INFORMATION_SCHEMA.TABLES AS A, INFORMATION_SCHEMA.COLUMNS AS B WHERE A.TABLE_SCHEMA = 'TestExecutor' AND A.TABLE_SCHEMA = 'TestXBadSharding'") @@ -3410,8 +3406,8 @@ func TestSelectScatterFails(t *testing.T) { func TestGen4SelectStraightJoin(t *testing.T) { executor, sbc1, _, _, _ := createExecutorEnv(t) executor.normalize = true - executor.pv = querypb.ExecuteOptions_Gen4 - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) query := "select u.id from user u straight_join user2 u2 on u.id = u2.id" _, err := executor.Execute(context.Background(), nil, "TestGen4SelectStraightJoin", @@ -3432,9 +3428,8 @@ func TestGen4SelectStraightJoin(t *testing.T) { func TestGen4MultiColumnVindexEqual(t *testing.T) { executor, sbc1, sbc2, _, _ := createExecutorEnv(t) executor.normalize = true - executor.pv = querypb.ExecuteOptions_Gen4 - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) query := "select * from user_region where cola = 1 and colb = 2" _, err := executor.Execute(context.Background(), nil, "TestGen4MultiColumnVindex", session, query, map[string]*querypb.BindVariable{}) require.NoError(t, err) @@ -3471,9 +3466,8 @@ func TestGen4MultiColumnVindexEqual(t *testing.T) { func TestGen4MultiColumnVindexIn(t *testing.T) { executor, sbc1, sbc2, _, _ := createExecutorEnv(t) executor.normalize = true - executor.pv = querypb.ExecuteOptions_Gen4 - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) query := "select * from user_region where cola IN (1,17984) and colb IN (2,3,4)" _, err := executor.Execute(context.Background(), nil, "TestGen4MultiColumnVindex", session, query, map[string]*querypb.BindVariable{}) require.NoError(t, err) @@ -3510,9 +3504,8 @@ func TestGen4MultiColumnVindexIn(t *testing.T) { func TestGen4MultiColMixedColComparision(t *testing.T) { executor, sbc1, sbc2, _, _ := createExecutorEnv(t) executor.normalize = true - executor.pv = querypb.ExecuteOptions_Gen4 - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) query := "select * from user_region where colb = 2 and cola IN (1,17984)" _, err := executor.Execute(context.Background(), nil, "TestGen4MultiColMixedColComparision", session, query, map[string]*querypb.BindVariable{}) require.NoError(t, err) @@ -3547,9 +3540,8 @@ func TestGen4MultiColMixedColComparision(t *testing.T) { func TestGen4MultiColBestVindexSel(t *testing.T) { executor, sbc1, sbc2, _, _ := createExecutorEnv(t) executor.normalize = true - executor.pv = querypb.ExecuteOptions_Gen4 - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) query := "select * from user_region where colb = 2 and cola IN (1,17984) and cola = 1" _, err := executor.Execute(context.Background(), nil, "TestGen4MultiColBestVindexSel", session, query, map[string]*querypb.BindVariable{}) require.NoError(t, err) @@ -3593,9 +3585,8 @@ func TestGen4MultiColBestVindexSel(t *testing.T) { func TestGen4MultiColMultiEqual(t *testing.T) { executor, sbc1, sbc2, _, _ := createExecutorEnv(t) executor.normalize = true - executor.pv = querypb.ExecuteOptions_Gen4 - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) query := "select * from user_region where (cola,colb) in ((17984,2),(17984,3))" _, err := executor.Execute(context.Background(), nil, "TestGen4MultiColMultiEqual", session, query, map[string]*querypb.BindVariable{}) require.NoError(t, err) @@ -3615,7 +3606,6 @@ func TestGen4MultiColMultiEqual(t *testing.T) { func TestGen4SelectUnqualifiedReferenceTable(t *testing.T) { executor, sbc1, sbc2, sbclookup, ctx := createExecutorEnv(t) - executor.pv = querypb.ExecuteOptions_Gen4 query := "select * from zip_detail" session := &vtgatepb.Session{ @@ -3636,7 +3626,6 @@ func TestGen4SelectUnqualifiedReferenceTable(t *testing.T) { func TestGen4SelectQualifiedReferenceTable(t *testing.T) { executor, sbc1, sbc2, sbclookup, ctx := createExecutorEnv(t) - executor.pv = querypb.ExecuteOptions_Gen4 query := fmt.Sprintf("select * from %s.zip_detail", KsTestSharded) session := &vtgatepb.Session{ @@ -3657,7 +3646,6 @@ func TestGen4SelectQualifiedReferenceTable(t *testing.T) { func TestGen4JoinUnqualifiedReferenceTable(t *testing.T) { executor, sbc1, sbc2, sbclookup, ctx := createExecutorEnv(t) - executor.pv = querypb.ExecuteOptions_Gen4 query := "select * from user join zip_detail on user.zip_detail_id = zip_detail.id" session := &vtgatepb.Session{ @@ -3694,7 +3682,6 @@ func TestGen4JoinUnqualifiedReferenceTable(t *testing.T) { func TestGen4CrossShardJoinQualifiedReferenceTable(t *testing.T) { executor, sbc1, sbc2, sbclookup, ctx := createExecutorEnv(t) - executor.pv = querypb.ExecuteOptions_Gen4 query := "select user.id from user join TestUnsharded.zip_detail on user.zip_detail_id = TestUnsharded.zip_detail.id" session := &vtgatepb.Session{ @@ -3751,7 +3738,6 @@ func TestRegionRange(t *testing.T) { } executor := createExecutor(ctx, serv, cell, resolver) defer executor.Close() - executor.pv = querypb.ExecuteOptions_Gen4 tcases := []struct { regionID int @@ -3769,7 +3755,7 @@ func TestRegionRange(t *testing.T) { for _, tcase := range tcases { t.Run(strconv.Itoa(tcase.regionID), func(t *testing.T) { sql := fmt.Sprintf("select * from user_region where cola = %d", tcase.regionID) - _, err := executor.Execute(context.Background(), nil, "TestRegionRange", NewAutocommitSession(&vtgatepb.Session{}), sql, nil) + _, err := executor.Execute(context.Background(), nil, "TestRegionRange", econtext.NewAutocommitSession(&vtgatepb.Session{}), sql, nil) require.NoError(t, err) count := 0 for _, sbc := range conns { @@ -3801,7 +3787,6 @@ func TestMultiCol(t *testing.T) { } executor := createExecutor(ctx, serv, cell, resolver) defer executor.Close() - executor.pv = querypb.ExecuteOptions_Gen4 tcases := []struct { cola, colb, colc int @@ -3817,7 +3802,7 @@ func TestMultiCol(t *testing.T) { shards: []string{"20a0-"}, }} - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) for _, tcase := range tcases { t.Run(fmt.Sprintf("%d_%d_%d", tcase.cola, tcase.colb, tcase.colc), func(t *testing.T) { @@ -3882,7 +3867,6 @@ func TestMultiColPartial(t *testing.T) { } executor := createExecutor(ctx, serv, cell, resolver) defer executor.Close() - executor.pv = querypb.ExecuteOptions_Gen4 tcases := []struct { where string @@ -3907,7 +3891,7 @@ func TestMultiColPartial(t *testing.T) { shards: []string{"20a0c0-"}, }} - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) for _, tcase := range tcases { t.Run(tcase.where, func(t *testing.T) { @@ -3946,7 +3930,6 @@ func TestSelectAggregationNoData(t *testing.T) { } executor := createExecutor(ctx, serv, cell, resolver) defer executor.Close() - executor.pv = querypb.ExecuteOptions_Gen4 tcases := []struct { sql string @@ -4038,7 +4021,6 @@ func TestSelectAggregationData(t *testing.T) { } executor := createExecutor(ctx, serv, cell, resolver) defer executor.Close() - executor.pv = querypb.ExecuteOptions_Gen4 tcases := []struct { sql string @@ -4196,8 +4178,7 @@ func TestSelectAggregationRandom(t *testing.T) { executor := createExecutor(ctx, serv, cell, resolver) defer executor.Close() - executor.pv = querypb.ExecuteOptions_Gen4 - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) rs, err := executor.Execute(context.Background(), nil, "TestSelectCFC", session, "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as c from (select sum(a) as a, sum(b) as b from user) A", nil) require.NoError(t, err) @@ -4207,7 +4188,7 @@ func TestSelectAggregationRandom(t *testing.T) { func TestSelectDateTypes(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) qr, err := executor.Execute(context.Background(), nil, "TestSelectDateTypes", session, "select '2020-01-01' + interval month(date_sub(FROM_UNIXTIME(1234), interval 1 month))-1 month", nil) require.NoError(t, err) @@ -4218,7 +4199,7 @@ func TestSelectDateTypes(t *testing.T) { func TestSelectHexAndBit(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) qr, err := executor.Execute(context.Background(), nil, "TestSelectHexAndBit", session, "select 0b1001, b'1001', 0x9, x'09'", nil) require.NoError(t, err) @@ -4234,7 +4215,7 @@ func TestSelectHexAndBit(t *testing.T) { func TestSelectCFC(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) _, err := executor.Execute(context.Background(), nil, "TestSelectCFC", session, "select /*vt+ PLANNER=gen4 */ c2 from tbl_cfc where c1 like 'A%'", nil) require.NoError(t, err) @@ -4263,7 +4244,7 @@ func TestSelectView(t *testing.T) { require.NoError(t, err) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) _, err = executor.Execute(context.Background(), nil, "TestSelectView", session, "select * from user_details_view", nil) require.NoError(t, err) @@ -4304,7 +4285,7 @@ func TestWarmingReads(t *testing.T) { executor, primary, replica := createExecutorEnvWithPrimaryReplicaConn(t, ctx, 100) executor.normalize = true - session := NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) // Since queries on the replica will run in a separate go-routine, we need synchronization for the Queries field in the sandboxconn. replica.RequireQueriesLocking() @@ -4368,6 +4349,7 @@ func TestWarmingReads(t *testing.T) { // waitUntilQueryCount waits until the number of queries run on the tablet reach the specified count. func waitUntilQueryCount(t *testing.T, tab *sandboxconn.SandboxConn, count int) { + t.Helper() timeout := time.After(1 * time.Second) for { select { @@ -4428,7 +4410,7 @@ func TestStreamJoinQuery(t *testing.T) { func TestSysVarGlobalAndSession(t *testing.T) { executor, sbc1, _, _, _ := createExecutorEnv(t) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, SystemVariables: map[string]string{}}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, SystemVariables: map[string]string{}}) sbc1.SetResults([]*sqltypes.Result{ sqltypes.MakeTestResult(sqltypes.MakeTestFields("innodb_lock_wait_timeout", "uint64"), "20"), diff --git a/go/vt/vtgate/executor_set_test.go b/go/vt/vtgate/executor_set_test.go index 12e8e272bd7..62101639a11 100644 --- a/go/vt/vtgate/executor_set_test.go +++ b/go/vt/vtgate/executor_set_test.go @@ -22,6 +22,7 @@ import ( "vitess.io/vitess/go/mysql/sqlerror" querypb "vitess.io/vitess/go/vt/proto/query" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/test/utils" @@ -266,7 +267,7 @@ func TestExecutorSet(t *testing.T) { }} for i, tcase := range testcases { t.Run(fmt.Sprintf("%d-%s", i, tcase.in), func(t *testing.T) { - session := NewSafeSession(&vtgatepb.Session{Autocommit: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{Autocommit: true}) _, err := executorEnv.Execute(ctx, nil, "TestExecute", session, tcase.in, nil) if tcase.err == "" { require.NoError(t, err) @@ -374,7 +375,7 @@ func TestExecutorSetOp(t *testing.T) { }} for _, tcase := range testcases { t.Run(tcase.in, func(t *testing.T) { - session := NewAutocommitSession(&vtgatepb.Session{ + session := econtext.NewAutocommitSession(&vtgatepb.Session{ TargetString: "@primary", }) session.TargetString = KsTestUnsharded @@ -392,7 +393,7 @@ func TestExecutorSetMetadata(t *testing.T) { t.Run("Session 1", func(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) set := "set @@vitess_metadata.app_keyspace_v1= '1'" _, err := executor.Execute(ctx, nil, "TestExecute", session, set, nil) @@ -406,7 +407,7 @@ func TestExecutorSetMetadata(t *testing.T) { }() executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) set := "set @@vitess_metadata.app_keyspace_v1= '1'" _, err := executor.Execute(ctx, nil, "TestExecute", session, set, nil) @@ -469,7 +470,7 @@ func TestPlanExecutorSetUDV(t *testing.T) { }} for _, tcase := range testcases { t.Run(tcase.in, func(t *testing.T) { - session := NewSafeSession(&vtgatepb.Session{Autocommit: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{Autocommit: true}) _, err := executor.Execute(ctx, nil, "TestExecute", session, tcase.in, nil) if err != nil { require.EqualError(t, err, tcase.err) @@ -515,7 +516,7 @@ func TestSetVar(t *testing.T) { executor, _, _, sbc, ctx := createCustomExecutor(t, "{}", "8.0.0") executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: KsTestUnsharded}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: KsTestUnsharded}) sbc.SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields("orig|new", "varchar|varchar"), @@ -554,7 +555,7 @@ func TestSetVarShowVariables(t *testing.T) { executor, _, _, sbc, ctx := createCustomExecutor(t, "{}", "8.0.0") executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: KsTestUnsharded}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{EnableSystemSettings: true, TargetString: KsTestUnsharded}) sbc.SetResults([]*sqltypes.Result{ // select query result for checking any change in system settings @@ -597,7 +598,7 @@ func TestExecutorSetAndSelect(t *testing.T) { sysVar: "tx_isolation", exp: `[[VARCHAR("READ-UNCOMMITTED")]]`, // this returns the value set in previous query. }} - session := NewAutocommitSession(&vtgatepb.Session{TargetString: KsTestUnsharded, EnableSystemSettings: true}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{TargetString: KsTestUnsharded, EnableSystemSettings: true}) for _, tcase := range testcases { t.Run(fmt.Sprintf("%s-%s", tcase.sysVar, tcase.val), func(t *testing.T) { sbc.ExecCount.Store(0) // reset the value @@ -631,7 +632,7 @@ func TestExecutorSetAndSelect(t *testing.T) { func TestExecutorTimeZone(t *testing.T) { e, _, _, _, ctx := createExecutorEnv(t) - session := NewAutocommitSession(&vtgatepb.Session{TargetString: KsTestUnsharded, EnableSystemSettings: true}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{TargetString: KsTestUnsharded, EnableSystemSettings: true}) session.SetSystemVariable("time_zone", "'+08:00'") qr, err := e.Execute(ctx, nil, "TestExecutorSetAndSelect", session, "select now()", nil) diff --git a/go/vt/vtgate/executor_stream_test.go b/go/vt/vtgate/executor_stream_test.go index b8cfeaf3cd5..a8500dd59c4 100644 --- a/go/vt/vtgate/executor_stream_test.go +++ b/go/vt/vtgate/executor_stream_test.go @@ -31,6 +31,7 @@ import ( topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/vtenv" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/logstats" _ "vitess.io/vitess/go/vt/vtgate/vindexes" "vitess.io/vitess/go/vt/vttablet/sandboxconn" @@ -102,7 +103,7 @@ func executorStreamMessages(executor *Executor, sql string) (qr *sqltypes.Result ctx, nil, "TestExecuteStream", - NewSafeSession(session), + econtext.NewSafeSession(session), sql, nil, func(qr *sqltypes.Result) error { diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 3732a37d1d1..2b6d4710bce 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -36,6 +36,8 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" @@ -64,7 +66,7 @@ func TestExecutorResultsExceeded(t *testing.T) { warnMemoryRows = 3 defer func() { warnMemoryRows = save }() - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) initial := warnings.Counts()["ResultsExceeded"] @@ -88,7 +90,7 @@ func TestExecutorMaxMemoryRowsExceeded(t *testing.T) { maxMemoryRows = 3 defer func() { maxMemoryRows = save }() - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) result := sqltypes.MakeTestResult(sqltypes.MakeTestFields("col", "int64"), "1", "2", "3", "4") fn := func(r *sqltypes.Result) error { return nil @@ -122,7 +124,7 @@ func TestExecutorMaxMemoryRowsExceeded(t *testing.T) { func TestExecutorTransactionsNoAutoCommit(t *testing.T) { executor, _, _, sbclookup, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", SessionUUID: "suuid"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", SessionUUID: "suuid"}) logChan := executor.queryLogger.Subscribe("Test") defer executor.queryLogger.Unsubscribe(logChan) @@ -188,7 +190,7 @@ func TestExecutorTransactionsNoAutoCommit(t *testing.T) { } // Prevent use of non-primary if in_transaction is on. - session = NewSafeSession(&vtgatepb.Session{TargetString: "@primary", InTransaction: true}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", InTransaction: true}) _, err = executor.Execute(ctx, nil, "TestExecute", session, "use @replica", nil) require.EqualError(t, err, `can't execute the given command because you have an active transaction`) } @@ -205,7 +207,7 @@ func TestDirectTargetRewrites(t *testing.T) { } sql := "select database()" - _, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) + _, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) assertQueries(t, sbclookup, []*querypb.BoundQuery{{ Sql: "select :__vtdbname as `database()` from dual", @@ -216,7 +218,7 @@ func TestDirectTargetRewrites(t *testing.T) { func TestExecutorTransactionsAutoCommit(t *testing.T) { executor, _, _, sbclookup, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true, SessionUUID: "suuid"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true, SessionUUID: "suuid"}) logChan := executor.queryLogger.Subscribe("Test") defer executor.queryLogger.Unsubscribe(logChan) @@ -270,7 +272,7 @@ func TestExecutorTransactionsAutoCommitStreaming(t *testing.T) { executor, _, _, sbclookup, ctx := createExecutorEnv(t) oltpOptions := &querypb.ExecuteOptions{Workload: querypb.ExecuteOptions_OLTP} - session := NewSafeSession(&vtgatepb.Session{ + session := econtext.NewSafeSession(&vtgatepb.Session{ TargetString: "@primary", Autocommit: true, Options: oltpOptions, @@ -339,7 +341,7 @@ func TestExecutorDeleteMetadata(t *testing.T) { }() executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) set := "set @@vitess_metadata.app_v1= '1'" _, err := executor.Execute(ctx, nil, "TestExecute", session, set, nil) @@ -367,7 +369,7 @@ func TestExecutorDeleteMetadata(t *testing.T) { func TestExecutorAutocommit(t *testing.T) { executor, _, _, sbclookup, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) logChan := executor.queryLogger.Subscribe("Test") defer executor.queryLogger.Unsubscribe(logChan) @@ -446,7 +448,7 @@ func TestExecutorAutocommit(t *testing.T) { // transition autocommit from 0 to 1 in the middle of a transaction. startCount = sbclookup.CommitCount.Load() - session = NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) _, err = executor.Execute(ctx, nil, "TestExecute", session, "begin", nil) require.NoError(t, err) _, err = executor.Execute(ctx, nil, "TestExecute", session, "update main1 set id=1", nil) @@ -468,7 +470,7 @@ func TestExecutorAutocommit(t *testing.T) { func TestExecutorShowColumns(t *testing.T) { executor, sbc1, sbc2, sbclookup, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: ""}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ""}) queries := []string{ "SHOW COLUMNS FROM `user` in `TestExecutor`", @@ -520,7 +522,7 @@ func assertMatchesNoOrder(t *testing.T, expected, got string) { func TestExecutorShow(t *testing.T) { executor, _, _, sbclookup, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) for _, query := range []string{"show vitess_keyspaces", "show keyspaces"} { qr, err := executor.Execute(ctx, nil, "TestExecute", session, query, nil) @@ -545,7 +547,7 @@ func TestExecutorShow(t *testing.T) { _, err = executor.Execute(ctx, nil, "TestExecute", session, "use @primary", nil) require.NoError(t, err) _, err = executor.Execute(ctx, nil, "TestExecute", session, "show tables", nil) - assert.EqualError(t, err, errNoKeyspace.Error(), "'show tables' should fail without a keyspace") + assert.EqualError(t, err, econtext.ErrNoKeyspace.Error(), "'show tables' should fail without a keyspace") assert.Empty(t, sbclookup.Queries, "sbclookup unexpectedly has queries already") showResults := &sqltypes.Result{ @@ -920,7 +922,7 @@ func TestExecutorShow(t *testing.T) { query = "show vschema vindexes on user" _, err = executor.Execute(ctx, nil, "TestExecute", session, query, nil) - wantErr := errNoKeyspace.Error() + wantErr := econtext.ErrNoKeyspace.Error() assert.EqualError(t, err, wantErr, query) query = "show vschema vindexes on TestExecutor.garbage" @@ -1024,7 +1026,7 @@ func TestExecutorShow(t *testing.T) { utils.MustMatch(t, wantqr, qr, fmt.Sprintf("%s, with a bad keyspace", query)) query = "show vschema tables" - session = NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) qr, err = executor.Execute(ctx, nil, "TestExecute", session, query, nil) require.NoError(t, err) wantqr = &sqltypes.Result{ @@ -1050,9 +1052,9 @@ func TestExecutorShow(t *testing.T) { utils.MustMatch(t, wantqr, qr, query) query = "show vschema tables" - session = NewSafeSession(&vtgatepb.Session{}) + session = econtext.NewSafeSession(&vtgatepb.Session{}) _, err = executor.Execute(ctx, nil, "TestExecute", session, query, nil) - want = errNoKeyspace.Error() + want = econtext.ErrNoKeyspace.Error() assert.EqualError(t, err, want, query) query = "show 10" @@ -1061,7 +1063,7 @@ func TestExecutorShow(t *testing.T) { assert.EqualError(t, err, want, query) query = "show vschema tables" - session = NewSafeSession(&vtgatepb.Session{TargetString: "no_such_keyspace"}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: "no_such_keyspace"}) _, err = executor.Execute(ctx, nil, "TestExecute", session, query, nil) want = "VT05003: unknown database 'no_such_keyspace' in vschema" assert.EqualError(t, err, want, query) @@ -1080,7 +1082,7 @@ func TestExecutorShow(t *testing.T) { func TestExecutorShowTargeted(t *testing.T) { executor, _, sbc2, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor/40-60"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor/40-60"}) queries := []string{ "show databases", @@ -1107,7 +1109,7 @@ func TestExecutorShowTargeted(t *testing.T) { func TestExecutorShowFromSystemSchema(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "mysql"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "mysql"}) _, err := executor.Execute(ctx, nil, "TestExecutorShowFromSystemSchema", session, "show tables", nil) require.NoError(t, err) @@ -1116,7 +1118,7 @@ func TestExecutorShowFromSystemSchema(t *testing.T) { func TestExecutorUse(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{Autocommit: true, TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{Autocommit: true, TargetString: "@primary"}) stmts := []string{ "use TestExecutor", @@ -1135,13 +1137,13 @@ func TestExecutorUse(t *testing.T) { utils.MustMatch(t, wantSession, session.Session, "session does not match") } - _, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{}), "use 1", nil) + _, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{}), "use 1", nil) wantErr := "syntax error at position 6 near '1'" if err == nil || err.Error() != wantErr { t.Errorf("got: %v, want %v", err, wantErr) } - _, err = executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{}), "use UnexistentKeyspace", nil) + _, err = executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{}), "use UnexistentKeyspace", nil) require.EqualError(t, err, "VT05003: unknown database 'UnexistentKeyspace' in vschema") } @@ -1155,7 +1157,7 @@ func TestExecutorComment(t *testing.T) { wantResult := &sqltypes.Result{} for _, stmt := range stmts { - gotResult, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) + gotResult, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) if err != nil { t.Error(err) } @@ -1240,9 +1242,9 @@ func TestExecutorDDL(t *testing.T) { sbc2.ExecCount.Store(0) sbclookup.ExecCount.Store(0) stmtType := "DDL" - _, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + _, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) if tc.hasNoKeyspaceErr { - require.EqualError(t, err, errNoKeyspace.Error(), "expect query to fail: %q", stmt) + require.EqualError(t, err, econtext.ErrNoKeyspace.Error(), "expect query to fail: %q", stmt) stmtType = "" // For error case, plan is not generated to query log will not contain any stmtType. } else { require.NoError(t, err, "did not expect error for query: %q", stmt) @@ -1278,9 +1280,9 @@ func TestExecutorDDL(t *testing.T) { sbc1.ExecCount.Store(0) sbc2.ExecCount.Store(0) sbclookup.ExecCount.Store(0) - _, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: ""}), stmt.input, nil) + _, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: ""}), stmt.input, nil) if stmt.hasErr { - require.EqualError(t, err, errNoKeyspace.Error(), "expect query to fail") + require.EqualError(t, err, econtext.ErrNoKeyspace.Error(), "expect query to fail") testQueryLog(t, executor, logChan, "TestExecute", "", stmt.input, 0) } else { require.NoError(t, err) @@ -1297,13 +1299,13 @@ func TestExecutorDDLFk(t *testing.T) { } for _, stmt := range stmts { - for _, fkMode := range []string{"allow", "disallow"} { - t.Run(stmt+fkMode, func(t *testing.T) { + for _, mode := range []string{"allow", "disallow"} { + t.Run(stmt+mode, func(t *testing.T) { executor, _, _, sbc, ctx := createExecutorEnv(t) sbc.ExecCount.Store(0) - foreignKeyMode = fkMode - _, err := executor.Execute(ctx, nil, mName, NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) - if fkMode == "allow" { + executor.vConfig.ForeignKeyMode = fkMode(mode) + _, err := executor.Execute(ctx, nil, mName, econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) + if mode == "allow" { require.NoError(t, err) require.EqualValues(t, 1, sbc.ExecCount.Load()) } else { @@ -1322,7 +1324,7 @@ func TestExecutorAlterVSchemaKeyspace(t *testing.T) { }() executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) vschemaUpdates := make(chan *vschemapb.SrvVSchema, 2) executor.serv.WatchSrvVSchema(ctx, executor.cell, func(vschema *vschemapb.SrvVSchema, err error) bool { @@ -1364,7 +1366,7 @@ func TestExecutorCreateVindexDDL(t *testing.T) { t.Fatalf("test_vindex should not exist in original vschema") } - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt := "alter vschema create vindex test_vindex using hash" _, err := executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) require.NoError(t, err) @@ -1388,7 +1390,7 @@ func TestExecutorCreateVindexDDL(t *testing.T) { // Create a new vschema keyspace implicitly by creating a vindex with a different // target in the session // ksNew := "test_new_keyspace" - session = NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt = "alter vschema create vindex test_vindex2 using hash" _, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) if err != nil { @@ -1439,7 +1441,7 @@ func TestExecutorAddDropVschemaTableDDL(t *testing.T) { vschemaTables = append(vschemaTables, t) } - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt := "alter vschema add table test_table" _, err := executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) require.NoError(t, err) @@ -1451,7 +1453,7 @@ func TestExecutorAddDropVschemaTableDDL(t *testing.T) { _ = waitForVschemaTables(t, ks, append([]string{"test_table", "test_table2"}, vschemaTables...), executor) // Should fail adding a table on a sharded keyspace - session = NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) stmt = "alter vschema add table test_table" _, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) require.EqualError(t, err, "add vschema table: unsupported on sharded keyspace TestExecutor") @@ -1470,7 +1472,7 @@ func TestExecutorVindexDDLACL(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) ks := "TestExecutor" - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) ctxRedUser := callerid.NewContext(ctx, &vtrpcpb.CallerID{}, &querypb.VTGateCallerID{Username: "redUser"}) ctxBlueUser := callerid.NewContext(ctx, &vtrpcpb.CallerID{}, &querypb.VTGateCallerID{Username: "blueUser"}) @@ -1515,7 +1517,7 @@ func TestExecutorVindexDDLACL(t *testing.T) { func TestExecutorUnrecognized(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - _, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{}), "invalid statement", nil) + _, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{}), "invalid statement", nil) require.Error(t, err, "unrecognized statement: invalid statement'") } @@ -1525,7 +1527,7 @@ func TestExecutorDeniedErrorNoBuffer(t *testing.T) { vschemaWaitTimeout = 500 * time.Millisecond - session := NewAutocommitSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{TargetString: "@primary"}) startExec := time.Now() _, err := executor.Execute(ctx, nil, "TestExecutorDeniedErrorNoBuffer", session, "select * from user", nil) require.NoError(t, err, "enforce denied tables not buffered") @@ -1559,9 +1561,8 @@ var pv = querypb.ExecuteOptions_Gen4 func TestGetPlanUnnormalized(t *testing.T) { r, _, _, _, ctx := createExecutorEnv(t) - - emptyvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) - unshardedvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + emptyvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, econtext.VCursorConfig{}) + unshardedvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, econtext.VCursorConfig{}) query1 := "select * from music_user_map where id = 1" plan1, logStats1 := getPlanCached(t, ctx, r, emptyvc, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, false) @@ -1604,7 +1605,7 @@ func assertCacheSize(t *testing.T, c *PlanCache, expected int) { } } -func assertCacheContains(t *testing.T, e *Executor, vc *vcursorImpl, sql string) *engine.Plan { +func assertCacheContains(t *testing.T, e *Executor, vc *econtext.VCursorImpl, sql string) *engine.Plan { t.Helper() var plan *engine.Plan @@ -1623,9 +1624,9 @@ func assertCacheContains(t *testing.T, e *Executor, vc *vcursorImpl, sql string) return plan } -func getPlanCached(t *testing.T, ctx context.Context, e *Executor, vcursor *vcursorImpl, sql string, comments sqlparser.MarginComments, bindVars map[string]*querypb.BindVariable, skipQueryPlanCache bool) (*engine.Plan, *logstats.LogStats) { +func getPlanCached(t *testing.T, ctx context.Context, e *Executor, vcursor *econtext.VCursorImpl, sql string, comments sqlparser.MarginComments, bindVars map[string]*querypb.BindVariable, skipQueryPlanCache bool) (*engine.Plan, *logstats.LogStats) { logStats := logstats.NewLogStats(ctx, "Test", "", "", nil) - vcursor.safeSession = &SafeSession{ + vcursor.SafeSession = &econtext.SafeSession{ Session: &vtgatepb.Session{ Options: &querypb.ExecuteOptions{SkipQueryPlanCache: skipQueryPlanCache}}, } @@ -1644,7 +1645,7 @@ func TestGetPlanCacheUnnormalized(t *testing.T) { t.Run("Cache", func(t *testing.T) { r, _, _, _, ctx := createExecutorEnv(t) - emptyvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + emptyvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, econtext.VCursorConfig{}) query1 := "select * from music_user_map where id = 1" _, logStats1 := getPlanCached(t, ctx, r, emptyvc, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, true) @@ -1668,7 +1669,7 @@ func TestGetPlanCacheUnnormalized(t *testing.T) { // Skip cache using directive r, _, _, _, ctx := createExecutorEnv(t) - unshardedvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + unshardedvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, r.vConfig) query1 := "insert /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ into user(id) values (1), (2)" getPlanCached(t, ctx, r, unshardedvc, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, false) @@ -1679,12 +1680,12 @@ func TestGetPlanCacheUnnormalized(t *testing.T) { assertCacheSize(t, r.plans, 1) // the target string will be resolved and become part of the plan cache key, which adds a new entry - ksIDVc1, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[deadbeef]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + ksIDVc1, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[deadbeef]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, r.vConfig) getPlanCached(t, ctx, r, ksIDVc1, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, false) assertCacheSize(t, r.plans, 2) // the target string will be resolved and become part of the plan cache key, as it's an unsharded ks, it will be the same entry as above - ksIDVc2, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[beefdead]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + ksIDVc2, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[beefdead]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, r.vConfig) getPlanCached(t, ctx, r, ksIDVc2, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, false) assertCacheSize(t, r.plans, 2) }) @@ -1694,7 +1695,7 @@ func TestGetPlanCacheNormalized(t *testing.T) { t.Run("Cache", func(t *testing.T) { r, _, _, _, ctx := createExecutorEnv(t) r.normalize = true - emptyvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + emptyvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, r.vConfig) query1 := "select * from music_user_map where id = 1" _, logStats1 := getPlanCached(t, ctx, r, emptyvc, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, true /* skipQueryPlanCache */) @@ -1711,7 +1712,7 @@ func TestGetPlanCacheNormalized(t *testing.T) { // Skip cache using directive r, _, _, _, ctx := createExecutorEnv(t) r.normalize = true - unshardedvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + unshardedvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, r.vConfig) query1 := "insert /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ into user(id) values (1), (2)" getPlanCached(t, ctx, r, unshardedvc, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, false) @@ -1722,12 +1723,12 @@ func TestGetPlanCacheNormalized(t *testing.T) { assertCacheSize(t, r.plans, 1) // the target string will be resolved and become part of the plan cache key, which adds a new entry - ksIDVc1, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[deadbeef]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + ksIDVc1, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[deadbeef]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, r.vConfig) getPlanCached(t, ctx, r, ksIDVc1, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, false) assertCacheSize(t, r.plans, 2) // the target string will be resolved and become part of the plan cache key, as it's an unsharded ks, it will be the same entry as above - ksIDVc2, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[beefdead]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + ksIDVc2, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "[beefdead]"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, r.vConfig) getPlanCached(t, ctx, r, ksIDVc2, query1, makeComments(" /* comment */"), map[string]*querypb.BindVariable{}, false) assertCacheSize(t, r.plans, 2) }) @@ -1737,8 +1738,8 @@ func TestGetPlanNormalized(t *testing.T) { r, _, _, _, ctx := createExecutorEnv(t) r.normalize = true - emptyvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) - unshardedvc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + emptyvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, econtext.VCursorConfig{}) + unshardedvc, _ := econtext.NewVCursorImpl(econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded + "@unknown"}), makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, econtext.VCursorConfig{}) query1 := "select * from music_user_map where id = 1" query2 := "select * from music_user_map where id = 2" @@ -1785,7 +1786,7 @@ func TestGetPlanPriority(t *testing.T) { {name: "empty priority", sql: "select * from music_user_map", expectedPriority: "", expectedError: nil}, } - session := NewSafeSession(&vtgatepb.Session{TargetString: "@unknown", Options: &querypb.ExecuteOptions{}}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@unknown", Options: &querypb.ExecuteOptions{}}) for _, aTestCase := range testCases { testCase := aTestCase @@ -1795,7 +1796,7 @@ func TestGetPlanPriority(t *testing.T) { r.normalize = true logStats := logstats.NewLogStats(ctx, "Test", "", "", nil) - vCursor, err := newVCursorImpl(session, makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, false, pv) + vCursor, err := econtext.NewVCursorImpl(session, makeComments(""), r, nil, r.vm, r.VSchema(), r.resolver.resolver, nil, nullResultsObserver{}, econtext.VCursorConfig{}) assert.NoError(t, err) stmt, err := sqlparser.NewTestParser().Parse(testCase.sql) @@ -1809,7 +1810,7 @@ func TestGetPlanPriority(t *testing.T) { } else { assert.NoError(t, err) assert.Equal(t, testCase.expectedPriority, priorityFromStatement) - assert.Equal(t, testCase.expectedPriority, vCursor.safeSession.Options.Priority) + assert.Equal(t, testCase.expectedPriority, vCursor.SafeSession.Options.Priority) } }) } @@ -1966,7 +1967,7 @@ func TestExecutorMaxPayloadSizeExceeded(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) warningCount := warnings.Counts()["WarnPayloadSizeExceeded"] testMaxPayloadSizeExceeded := []string{ "select * from main1", @@ -2014,7 +2015,7 @@ func TestOlapSelectDatabase(t *testing.T) { cbInvoked = true return nil } - err := executor.StreamExecute(context.Background(), nil, "TestExecute", NewSafeSession(session), sql, nil, cb) + err := executor.StreamExecute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, nil, cb) assert.NoError(t, err) assert.True(t, cbInvoked) } @@ -2022,7 +2023,7 @@ func TestOlapSelectDatabase(t *testing.T) { func TestExecutorClearsWarnings(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{ + session := econtext.NewSafeSession(&vtgatepb.Session{ Warnings: []*querypb.QueryWarning{{Code: 234, Message: "oh noes"}}, }) _, err := executor.Execute(context.Background(), nil, "TestExecute", session, "select 42", nil) @@ -2039,7 +2040,6 @@ func TestServingKeyspaces(t *testing.T) { executor, sbc1, _, sbclookup, ctx := createExecutorEnv(t) - executor.pv = querypb.ExecuteOptions_Gen4 gw, ok := executor.resolver.resolver.GetGateway().(*TabletGateway) require.True(t, ok) hc := gw.hc.(*discovery.FakeHealthCheck) @@ -2058,7 +2058,7 @@ func TestServingKeyspaces(t *testing.T) { }) require.ElementsMatch(t, []string{"TestExecutor", "TestUnsharded"}, gw.GetServingKeyspaces()) - result, err := executor.Execute(ctx, nil, "TestServingKeyspaces", NewSafeSession(&vtgatepb.Session{}), "select keyspace_name from dual", nil) + result, err := executor.Execute(ctx, nil, "TestServingKeyspaces", econtext.NewSafeSession(&vtgatepb.Session{}), "select keyspace_name from dual", nil) require.NoError(t, err) require.Equal(t, `[[VARCHAR("TestExecutor")]]`, fmt.Sprintf("%v", result.Rows)) @@ -2074,7 +2074,7 @@ func TestServingKeyspaces(t *testing.T) { // Clear plan cache, to force re-planning of the query. executor.ClearPlans() require.ElementsMatch(t, []string{"TestUnsharded"}, gw.GetServingKeyspaces()) - result, err = executor.Execute(ctx, nil, "TestServingKeyspaces", NewSafeSession(&vtgatepb.Session{}), "select keyspace_name from dual", nil) + result, err = executor.Execute(ctx, nil, "TestServingKeyspaces", econtext.NewSafeSession(&vtgatepb.Session{}), "select keyspace_name from dual", nil) require.NoError(t, err) require.Equal(t, `[[VARCHAR("TestUnsharded")]]`, fmt.Sprintf("%v", result.Rows)) } @@ -2150,9 +2150,9 @@ func TestExecutorOther(t *testing.T) { sbc2.ExecCount.Store(0) sbclookup.ExecCount.Store(0) - _, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + _, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) if tc.hasNoKeyspaceErr { - assert.Error(t, err, errNoKeyspace) + assert.Error(t, err, econtext.ErrNoKeyspace.Error()) } else if tc.hasDestinationShardErr { assert.Errorf(t, err, "Destination can only be a single shard for statement: %s", stmt) } else { @@ -2206,7 +2206,7 @@ func TestExecutorAnalyze(t *testing.T) { sbc2.ExecCount.Store(0) sbclookup.ExecCount.Store(0) - _, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + _, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) require.NoError(t, err) utils.MustMatch(t, tc.wantCnts, cnts{ @@ -2270,7 +2270,7 @@ func TestExecutorExplainStmt(t *testing.T) { sbc2.ExecCount.Store(0) sbclookup.ExecCount.Store(0) - _, err := executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + _, err := executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) assert.NoError(t, err) utils.MustMatch(t, tc.wantCnts, cnts{ @@ -2360,9 +2360,9 @@ func TestExecutorOtherAdmin(t *testing.T) { sbc2.ExecCount.Store(0) sbclookup.ExecCount.Store(0) - _, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) + _, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), stmt, nil) if tc.hasNoKeyspaceErr { - assert.Error(t, err, errNoKeyspace) + assert.Error(t, err, econtext.ErrNoKeyspace.Error()) } else if tc.hasDestinationShardErr { assert.Errorf(t, err, "Destination can only be a single shard for statement: %s, got: DestinationExactKeyRange(-)", stmt) } else { @@ -2387,7 +2387,7 @@ func TestExecutorSavepointInTx(t *testing.T) { logChan := executor.queryLogger.Subscribe("TestExecutorSavepoint") defer executor.queryLogger.Unsubscribe(logChan) - session := NewSafeSession(&vtgatepb.Session{Autocommit: false, TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{Autocommit: false, TargetString: "@primary"}) _, err := exec(executor, session, "savepoint a") require.NoError(t, err) _, err = exec(executor, session, "rollback to a") @@ -2470,7 +2470,7 @@ func TestExecutorSavepointInTxWithReservedConn(t *testing.T) { logChan := executor.queryLogger.Subscribe("TestExecutorSavepoint") defer executor.queryLogger.Unsubscribe(logChan) - session := NewSafeSession(&vtgatepb.Session{Autocommit: true, TargetString: "TestExecutor", EnableSystemSettings: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{Autocommit: true, TargetString: "TestExecutor", EnableSystemSettings: true}) sbc1.SetResults([]*sqltypes.Result{ sqltypes.MakeTestResult(sqltypes.MakeTestFields("orig|new", "varchar|varchar"), "a|"), }) @@ -2537,7 +2537,7 @@ func TestExecutorSavepointWithoutTx(t *testing.T) { logChan := executor.queryLogger.Subscribe("TestExecutorSavepoint") defer executor.queryLogger.Unsubscribe(logChan) - session := NewSafeSession(&vtgatepb.Session{Autocommit: true, TargetString: "@primary", InTransaction: false}) + session := econtext.NewSafeSession(&vtgatepb.Session{Autocommit: true, TargetString: "@primary", InTransaction: false}) _, err := exec(executor, session, "savepoint a") require.NoError(t, err) _, err = exec(executor, session, "rollback to a") @@ -2622,9 +2622,9 @@ func TestExecutorCallProc(t *testing.T) { sbc2.ExecCount.Store(0) sbcUnsharded.ExecCount.Store(0) - _, err := executor.Execute(context.Background(), nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), "CALL proc()", nil) + _, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), "CALL proc()", nil) if tc.hasNoKeyspaceErr { - assert.EqualError(t, err, errNoKeyspace.Error()) + assert.EqualError(t, err, econtext.ErrNoKeyspace.Error()) } else if tc.unshardedOnlyErr { require.EqualError(t, err, "CALL is not supported for sharded keyspace") } else { @@ -2644,9 +2644,9 @@ func TestExecutorTempTable(t *testing.T) { executor, _, _, sbcUnsharded, ctx := createExecutorEnv(t) initialWarningsCount := warnings.Counts()["WarnUnshardedOnly"] - executor.warnShardedOnly = true + executor.vConfig.WarnShardedOnly = true creatQuery := "create temporary table temp_t(id bigint primary key)" - session := NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}) _, err := executor.Execute(ctx, nil, "TestExecutorTempTable", session, creatQuery, nil) require.NoError(t, err) assert.EqualValues(t, 1, sbcUnsharded.ExecCount.Load()) @@ -2665,7 +2665,7 @@ func TestExecutorShowVitessMigrations(t *testing.T) { executor, sbc1, sbc2, _, ctx := createExecutorEnv(t) showQuery := "show vitess_migrations" - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) _, err := executor.Execute(ctx, nil, "", session, showQuery, nil) require.NoError(t, err) assert.Contains(t, sbc1.StringQueries(), "show vitess_migrations") @@ -2676,7 +2676,7 @@ func TestExecutorDescHash(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) showQuery := "desc hash_index" - session := NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) _, err := executor.Execute(ctx, nil, "", session, showQuery, nil) require.NoError(t, err) } @@ -2684,7 +2684,7 @@ func TestExecutorDescHash(t *testing.T) { func TestExecutorVExplainQueries(t *testing.T) { executor, _, _, sbclookup, ctx := createExecutorEnv(t) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) sbclookup.SetResults([]*sqltypes.Result{ sqltypes.MakeTestResult(sqltypes.MakeTestFields("name|user_id", "varchar|int64"), "apa|1", "apa|2"), @@ -2697,7 +2697,7 @@ func TestExecutorVExplainQueries(t *testing.T) { // Test the streaming side as well var results []sqltypes.Row - session = NewAutocommitSession(&vtgatepb.Session{}) + session = econtext.NewAutocommitSession(&vtgatepb.Session{}) err = executor.StreamExecute(ctx, nil, "TestExecutorVExplainQueries", session, "vexplain queries select * from user where name = 'apa'", nil, func(result *sqltypes.Result) error { results = append(results, result.Rows...) return nil @@ -2710,7 +2710,7 @@ func TestExecutorVExplainQueries(t *testing.T) { func TestExecutorStartTxnStmt(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) tcases := []struct { beginSQL string @@ -2757,7 +2757,7 @@ func TestExecutorPrepareExecute(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) executor.normalize = true - session := NewAutocommitSession(&vtgatepb.Session{}) + session := econtext.NewAutocommitSession(&vtgatepb.Session{}) // prepare statement. _, err := executor.Execute(context.Background(), nil, "TestExecutorPrepareExecute", session, "prepare prep_user from 'select * from user where id = ?'", nil) @@ -2834,7 +2834,7 @@ func TestExecutorSettingsInTwoPC(t *testing.T) { sbc2.SetResults(tcase.testRes) // create a new session - session := NewSafeSession(&vtgatepb.Session{ + session := econtext.NewSafeSession(&vtgatepb.Session{ TargetString: KsTestSharded, TransactionMode: vtgatepb.TransactionMode_TWOPC, EnableSystemSettings: true, @@ -2892,7 +2892,7 @@ func TestExecutorRejectTwoPC(t *testing.T) { sbc2.SetResults(tcase.testRes) // create a new session - session := NewSafeSession(&vtgatepb.Session{ + session := econtext.NewSafeSession(&vtgatepb.Session{ TargetString: KsTestSharded, TransactionMode: vtgatepb.TransactionMode_TWOPC, EnableSystemSettings: true, @@ -2922,7 +2922,7 @@ func TestExecutorTruncateErrors(t *testing.T) { truncateErrorLen = 32 defer func() { truncateErrorLen = save }() - session := NewSafeSession(&vtgatepb.Session{}) + session := econtext.NewSafeSession(&vtgatepb.Session{}) fn := func(r *sqltypes.Result) error { return nil } @@ -2982,7 +2982,7 @@ func TestExecutorFlushStmt(t *testing.T) { for _, tc := range tcs { t.Run(tc.query+tc.targetStr, func(t *testing.T) { - _, err := executor.Execute(context.Background(), nil, "TestExecutorFlushStmt", NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), tc.query, nil) + _, err := executor.Execute(context.Background(), nil, "TestExecutorFlushStmt", econtext.NewSafeSession(&vtgatepb.Session{TargetString: tc.targetStr}), tc.query, nil) if tc.expectedErr == "" { require.NoError(t, err) } else { @@ -3029,7 +3029,7 @@ func TestExecutorKillStmt(t *testing.T) { allowKillStmt = !tc.disallow t.Run("execute:"+tc.query+tc.errStr, func(t *testing.T) { mysqlCtx := &fakeMysqlConnection{ErrMsg: tc.errStr} - _, err := executor.Execute(context.Background(), mysqlCtx, "TestExecutorKillStmt", NewAutocommitSession(&vtgatepb.Session{}), tc.query, nil) + _, err := executor.Execute(context.Background(), mysqlCtx, "TestExecutorKillStmt", econtext.NewAutocommitSession(&vtgatepb.Session{}), tc.query, nil) if tc.errStr != "" { require.ErrorContains(t, err, tc.errStr) } else { @@ -3039,7 +3039,7 @@ func TestExecutorKillStmt(t *testing.T) { }) t.Run("stream:"+tc.query+tc.errStr, func(t *testing.T) { mysqlCtx := &fakeMysqlConnection{ErrMsg: tc.errStr} - err := executor.StreamExecute(context.Background(), mysqlCtx, "TestExecutorKillStmt", NewAutocommitSession(&vtgatepb.Session{}), tc.query, nil, func(result *sqltypes.Result) error { + err := executor.StreamExecute(context.Background(), mysqlCtx, "TestExecutorKillStmt", econtext.NewAutocommitSession(&vtgatepb.Session{}), tc.query, nil, func(result *sqltypes.Result) error { return nil }) if tc.errStr != "" { @@ -3075,7 +3075,7 @@ func (f *fakeMysqlConnection) KillConnection(ctx context.Context, connID uint32) var _ vtgateservice.MySQLConnection = (*fakeMysqlConnection)(nil) -func exec(executor *Executor, session *SafeSession, sql string) (*sqltypes.Result, error) { +func exec(executor *Executor, session *econtext.SafeSession, sql string) (*sqltypes.Result, error) { return executor.Execute(context.Background(), nil, "TestExecute", session, sql, nil) } diff --git a/go/vt/vtgate/executor_vexplain_test.go b/go/vt/vtgate/executor_vexplain_test.go index 99eb77c7ed4..a9516492f1b 100644 --- a/go/vt/vtgate/executor_vexplain_test.go +++ b/go/vt/vtgate/executor_vexplain_test.go @@ -26,6 +26,8 @@ import ( "github.com/stretchr/testify/assert" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -135,7 +137,7 @@ func TestVExplainKeys(t *testing.T) { for _, tt := range tests { t.Run(tt.Query, func(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.Query, nil) require.NoError(t, err) diff --git a/go/vt/vtgate/executor_vschema_ddl_test.go b/go/vt/vtgate/executor_vschema_ddl_test.go index 1c912ed0d62..825b65ab8f3 100644 --- a/go/vt/vtgate/executor_vschema_ddl_test.go +++ b/go/vt/vtgate/executor_vschema_ddl_test.go @@ -25,6 +25,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/callerid" @@ -138,7 +140,7 @@ func TestPlanExecutorAlterVSchemaKeyspace(t *testing.T) { vschemaacl.AuthorizedDDLUsers = "" }() executor, _, _, _, ctx := createExecutorEnv(t) - session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) vschemaUpdates := make(chan *vschemapb.SrvVSchema, 2) executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { @@ -180,7 +182,7 @@ func TestPlanExecutorCreateVindexDDL(t *testing.T) { t.Fatalf("test_vindex should not exist in original vschema") } - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt := "alter vschema create vindex test_vindex using hash" _, err := executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) require.NoError(t, err) @@ -222,7 +224,7 @@ func TestPlanExecutorDropVindexDDL(t *testing.T) { t.Fatalf("test_vindex should not exist in original vschema") } - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt := "alter vschema drop vindex test_vindex" _, err := executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) wantErr := "vindex test_vindex does not exists in keyspace TestExecutor" @@ -296,7 +298,7 @@ func TestPlanExecutorAddDropVschemaTableDDL(t *testing.T) { vschemaTables = append(vschemaTables, t) } - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt := "alter vschema add table test_table" _, err := executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) require.NoError(t, err) @@ -308,7 +310,7 @@ func TestPlanExecutorAddDropVschemaTableDDL(t *testing.T) { _ = waitForVschemaTables(t, ks, append([]string{"test_table", "test_table2"}, vschemaTables...), executor) // Should fail adding a table on a sharded keyspace - session = NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}) stmt = "alter vschema add table test_table" _, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) wantErr := "add vschema table: unsupported on sharded keyspace TestExecutor" @@ -343,7 +345,7 @@ func TestExecutorAddSequenceDDL(t *testing.T) { vschemaTables = append(vschemaTables, t) } - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt := "alter vschema add sequence test_seq" _, err := executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) require.NoError(t, err) @@ -357,7 +359,7 @@ func TestExecutorAddSequenceDDL(t *testing.T) { // Should fail adding a table on a sharded keyspace ksSharded := "TestExecutor" - session = NewSafeSession(&vtgatepb.Session{TargetString: ksSharded}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: ksSharded}) stmt = "alter vschema add sequence sequence_table" _, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) @@ -403,7 +405,7 @@ func TestExecutorDropSequenceDDL(t *testing.T) { t.Fatalf("test_seq should not exist in original vschema") } - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) // add test sequence stmt := "alter vschema add sequence test_seq" @@ -428,7 +430,7 @@ func TestExecutorDropSequenceDDL(t *testing.T) { } // Should fail dropping a non-existing test sequence - session = NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session = econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt = "alter vschema drop sequence test_seq" _, err = executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) @@ -447,7 +449,7 @@ func TestExecutorDropAutoIncDDL(t *testing.T) { executor, _, _, _, ctx := createExecutorEnv(t) ks := KsTestUnsharded - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) stmt := "alter vschema add table test_table" _, err := executor.Execute(ctx, nil, "TestExecute", session, stmt, nil) @@ -488,7 +490,7 @@ func TestExecutorAddDropVindexDDL(t *testing.T) { }() executor, sbc1, sbc2, sbclookup, ctx := createExecutorEnv(t) ks := "TestExecutor" - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) vschemaUpdates := make(chan *vschemapb.SrvVSchema, 4) executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema @@ -706,7 +708,7 @@ func TestExecutorAddDropVindexDDL(t *testing.T) { require.EqualError(t, err, "table TestExecutor.nonexistent not defined in vschema") stmt = "alter vschema on nonexistent drop vindex test_lookup" - _, err = executor.Execute(ctx, nil, "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: "InvalidKeyspace"}), stmt, nil) + _, err = executor.Execute(ctx, nil, "TestExecute", econtext.NewSafeSession(&vtgatepb.Session{TargetString: "InvalidKeyspace"}), stmt, nil) require.EqualError(t, err, "VT05003: unknown database 'InvalidKeyspace' in vschema") stmt = "alter vschema on nowhere.nohow drop vindex test_lookup" @@ -731,7 +733,7 @@ func TestPlanExecutorVindexDDLACL(t *testing.T) { // t.Skip("not yet planned") executor, _, _, _, ctx := createExecutorEnv(t) ks := "TestExecutor" - session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: ks}) ctxRedUser := callerid.NewContext(ctx, &vtrpcpb.CallerID{}, &querypb.VTGateCallerID{Username: "redUser"}) ctxBlueUser := callerid.NewContext(ctx, &vtrpcpb.CallerID{}, &querypb.VTGateCallerID{Username: "blueUser"}) diff --git a/go/vt/vtgate/executor_vstream_test.go b/go/vt/vtgate/executor_vstream_test.go index 5466e9e8f3f..22fb7ee1034 100644 --- a/go/vt/vtgate/executor_vstream_test.go +++ b/go/vt/vtgate/executor_vstream_test.go @@ -21,6 +21,7 @@ import ( "time" "vitess.io/vitess/go/vt/vtgate/engine" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" querypb "vitess.io/vitess/go/vt/proto/query" @@ -76,7 +77,7 @@ func TestVStreamSQLUnsharded(t *testing.T) { results := make(chan *sqltypes.Result, 20) go func() { - err := executor.StreamExecute(ctx, nil, "TestExecuteStream", NewAutocommitSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), sql, nil, func(qr *sqltypes.Result) error { + err := executor.StreamExecute(ctx, nil, "TestExecuteStream", econtext.NewAutocommitSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), sql, nil, func(qr *sqltypes.Result) error { results <- qr return nil }) diff --git a/go/vt/vtgate/executorcontext/faketopo.go b/go/vt/vtgate/executorcontext/faketopo.go new file mode 100644 index 00000000000..f61119dce15 --- /dev/null +++ b/go/vt/vtgate/executorcontext/faketopo.go @@ -0,0 +1,68 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package executorcontext + +import ( + "context" + "encoding/hex" + + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/vt/topo" +) + +type FakeTopoServer struct{} + +// GetTopoServer returns the full topo.Server instance. +func (f *FakeTopoServer) GetTopoServer() (*topo.Server, error) { + return nil, nil +} + +// GetSrvKeyspaceNames returns the list of keyspaces served in +// the provided cell. +func (f *FakeTopoServer) GetSrvKeyspaceNames(ctx context.Context, cell string, staleOK bool) ([]string, error) { + return []string{"ks1"}, nil +} + +// GetSrvKeyspace returns the SrvKeyspace for a cell/keyspace. +func (f *FakeTopoServer) GetSrvKeyspace(ctx context.Context, cell, keyspace string) (*topodatapb.SrvKeyspace, error) { + zeroHexBytes, _ := hex.DecodeString("") + eightyHexBytes, _ := hex.DecodeString("80") + ks := &topodatapb.SrvKeyspace{ + Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{ + { + ServedType: topodatapb.TabletType_PRIMARY, + ShardReferences: []*topodatapb.ShardReference{ + {Name: "-80", KeyRange: &topodatapb.KeyRange{Start: zeroHexBytes, End: eightyHexBytes}}, + {Name: "80-", KeyRange: &topodatapb.KeyRange{Start: eightyHexBytes, End: zeroHexBytes}}, + }, + }, + }, + } + return ks, nil +} + +func (f *FakeTopoServer) WatchSrvKeyspace(ctx context.Context, cell, keyspace string, callback func(*topodatapb.SrvKeyspace, error) bool) { + ks, err := f.GetSrvKeyspace(ctx, cell, keyspace) + callback(ks, err) +} + +// WatchSrvVSchema starts watching the SrvVSchema object for +// the provided cell. It will call the callback when +// a new value or an error occurs. +func (f *FakeTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) { +} diff --git a/go/vt/vtgate/safe_session.go b/go/vt/vtgate/executorcontext/safe_session.go similarity index 89% rename from go/vt/vtgate/safe_session.go rename to go/vt/vtgate/executorcontext/safe_session.go index ec9229075e3..c77bba76ff8 100644 --- a/go/vt/vtgate/safe_session.go +++ b/go/vt/vtgate/executorcontext/safe_session.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package vtgate +package executorcontext import ( "fmt" @@ -55,7 +55,7 @@ type ( rollbackOnPartialExec string savepointName string - // this is a signal that found_rows has already been handles by the primitives, + // this is a signal that found_rows has already been handled by the primitives, // and doesn't have to be updated by the executor foundRowsHandled bool @@ -63,12 +63,12 @@ type ( // as the query that started a new transaction on the shard belong to a vindex. queryFromVindex bool - logging *executeLogger + logging *ExecuteLogger *vtgatepb.Session } - executeLogger struct { + ExecuteLogger struct { mu sync.Mutex entries []engine.ExecuteEntry lastID int @@ -124,6 +124,8 @@ const ( savepointRollback ) +const TxRollback = "Rollback Transaction" + // NewSafeSession returns a new SafeSession based on the Session func NewSafeSession(sessn *vtgatepb.Session) *SafeSession { if sessn == nil { @@ -202,6 +204,50 @@ func (session *SafeSession) resetCommonLocked() { } } +// NewAutocommitSession returns a SafeSession based on the original +// session, but with autocommit enabled. +func (session *SafeSession) NewAutocommitSession() *SafeSession { + ss := NewAutocommitSession(session.Session) + ss.logging = session.logging + return ss +} + +// IsFoundRowsHandled returns the foundRowsHandled. +func (session *SafeSession) IsFoundRowsHandled() bool { + session.mu.Lock() + defer session.mu.Unlock() + return session.foundRowsHandled +} + +// SetFoundRows set the found rows value. +func (session *SafeSession) SetFoundRows(value uint64) { + session.mu.Lock() + defer session.mu.Unlock() + session.FoundRows = value + session.foundRowsHandled = true +} + +// GetRollbackOnPartialExec returns the rollbackOnPartialExec value. +func (session *SafeSession) GetRollbackOnPartialExec() string { + session.mu.Lock() + defer session.mu.Unlock() + return session.rollbackOnPartialExec +} + +// SetQueryFromVindex set the queryFromVindex value. +func (session *SafeSession) SetQueryFromVindex(value bool) { + session.mu.Lock() + defer session.mu.Unlock() + session.queryFromVindex = value +} + +// GetQueryFromVindex returns the queryFromVindex value. +func (session *SafeSession) GetQueryFromVindex() bool { + session.mu.Lock() + defer session.mu.Unlock() + return session.queryFromVindex +} + // SetQueryTimeout sets the query timeout func (session *SafeSession) SetQueryTimeout(queryTimeout int64) { session.mu.Lock() @@ -309,7 +355,7 @@ func (session *SafeSession) SetRollbackCommand() { if session.savepointState == savepointSet { session.rollbackOnPartialExec = fmt.Sprintf("rollback to %s", session.savepointName) } else { - session.rollbackOnPartialExec = txRollback + session.rollbackOnPartialExec = TxRollback } session.savepointState = savepointRollbackSet } @@ -337,6 +383,18 @@ func (session *SafeSession) SetCommitOrder(co vtgatepb.CommitOrder) { session.commitOrder = co } +// GetCommitOrder returns the commit order. +func (session *SafeSession) GetCommitOrder() vtgatepb.CommitOrder { + session.mu.Lock() + defer session.mu.Unlock() + return session.commitOrder +} + +// GetLogger returns executor logger. +func (session *SafeSession) GetLogger() *ExecuteLogger { + return session.logging +} + // InTransaction returns true if we are in a transaction func (session *SafeSession) InTransaction() bool { session.mu.Lock() @@ -410,15 +468,22 @@ func (session *SafeSession) findSessionLocked(keyspace, shard string, tabletType return nil } +type ShardActionInfo interface { + TransactionID() int64 + ReservedID() int64 + RowsAffected() bool + Alias() *topodatapb.TabletAlias +} + // AppendOrUpdate adds a new ShardSession, or updates an existing one if one already exists for the given shard session -func (session *SafeSession) AppendOrUpdate(target *querypb.Target, info *shardActionInfo, existingSession *vtgatepb.Session_ShardSession, txMode vtgatepb.TransactionMode) error { +func (session *SafeSession) AppendOrUpdate(target *querypb.Target, info ShardActionInfo, existingSession *vtgatepb.Session_ShardSession, txMode vtgatepb.TransactionMode) error { session.mu.Lock() defer session.mu.Unlock() // additional check of transaction id is required // as now in autocommit mode there can be session due to reserved connection // that needs to be stored as shard session. - if session.autocommitState == autocommitted && info.transactionID != 0 { + if session.autocommitState == autocommitted && info.TransactionID() != 0 { // Should be unreachable return vterrors.VT13001("unexpected 'autocommitted' state in transaction") } @@ -429,10 +494,10 @@ func (session *SafeSession) AppendOrUpdate(target *querypb.Target, info *shardAc session.autocommitState = notAutocommittable if existingSession != nil { - existingSession.TransactionId = info.transactionID - existingSession.ReservedId = info.reservedID + existingSession.TransactionId = info.TransactionID() + existingSession.ReservedId = info.ReservedID() if !existingSession.RowsAffected { - existingSession.RowsAffected = info.rowsAffected + existingSession.RowsAffected = info.RowsAffected() } if existingSession.VindexOnly { existingSession.VindexOnly = session.queryFromVindex @@ -444,10 +509,10 @@ func (session *SafeSession) AppendOrUpdate(target *querypb.Target, info *shardAc } newSession := &vtgatepb.Session_ShardSession{ Target: target, - TabletAlias: info.alias, - TransactionId: info.transactionID, - ReservedId: info.reservedID, - RowsAffected: info.rowsAffected, + TabletAlias: info.Alias(), + TransactionId: info.TransactionID(), + ReservedId: info.ReservedID(), + RowsAffected: info.RowsAffected(), VindexOnly: session.queryFromVindex, } @@ -700,12 +765,11 @@ func (session *SafeSession) UpdateLockHeartbeat() { session.LastLockHeartbeat = time.Now().Unix() } -// TriggerLockHeartBeat returns if it time to trigger next lock heartbeat -func (session *SafeSession) TriggerLockHeartBeat() bool { +// GetLockHeartbeat returns last time the lock heartbeat was sent. +func (session *SafeSession) GetLockHeartbeat() int64 { session.mu.Lock() defer session.mu.Unlock() - now := time.Now().Unix() - return now-session.LastLockHeartbeat >= int64(lockHeartbeatTime.Seconds()) + return session.LastLockHeartbeat } // InLockSession returns whether locking is used on this session. @@ -858,9 +922,7 @@ func (session *SafeSession) GetOrCreateOptions() *querypb.ExecuteOptions { return session.Session.Options } -var _ iQueryOption = (*SafeSession)(nil) - -func (session *SafeSession) cachePlan() bool { +func (session *SafeSession) CachePlan() bool { if session == nil || session.Options == nil { return true } @@ -871,7 +933,7 @@ func (session *SafeSession) cachePlan() bool { return !(session.Options.SkipQueryPlanCache || session.Options.HasCreatedTempTables) } -func (session *SafeSession) getSelectLimit() int { +func (session *SafeSession) GetSelectLimit() int { if session == nil || session.Options == nil { return -1 } @@ -882,16 +944,16 @@ func (session *SafeSession) getSelectLimit() int { return int(session.Options.SqlSelectLimit) } -// isTxOpen returns true if there is open connection to any of the shard. -func (session *SafeSession) isTxOpen() bool { +// IsTxOpen returns true if there is open connection to any of the shard. +func (session *SafeSession) IsTxOpen() bool { session.mu.Lock() defer session.mu.Unlock() return len(session.ShardSessions) > 0 || len(session.PreSessions) > 0 || len(session.PostSessions) > 0 } -// getSessions returns the shard session for the current commit order. -func (session *SafeSession) getSessions() []*vtgatepb.Session_ShardSession { +// GetSessions returns the shard session for the current commit order. +func (session *SafeSession) GetSessions() []*vtgatepb.Session_ShardSession { session.mu.Lock() defer session.mu.Unlock() @@ -978,7 +1040,7 @@ func (session *SafeSession) EnableLogging(parser *sqlparser.Parser) { session.mu.Lock() defer session.mu.Unlock() - session.logging = &executeLogger{ + session.logging = &ExecuteLogger{ parser: parser, } } @@ -1016,7 +1078,15 @@ func (session *SafeSession) GetPrepareData(name string) *vtgatepb.PrepareData { return session.PrepareStatement[name] } -func (l *executeLogger) log(primitive engine.Primitive, target *querypb.Target, gateway srvtopo.Gateway, query string, begin bool, bv map[string]*querypb.BindVariable) { +func (session *SafeSession) Log(primitive engine.Primitive, target *querypb.Target, gateway srvtopo.Gateway, query string, begin bool, bv map[string]*querypb.BindVariable) { + session.logging.Log(primitive, target, gateway, query, begin, bv) +} + +func (session *SafeSession) GetLogs() []engine.ExecuteEntry { + return session.logging.GetLogs() +} + +func (l *ExecuteLogger) Log(primitive engine.Primitive, target *querypb.Target, gateway srvtopo.Gateway, query string, begin bool, bv map[string]*querypb.BindVariable) { if l == nil { return } @@ -1055,7 +1125,10 @@ func (l *executeLogger) log(primitive engine.Primitive, target *querypb.Target, }) } -func (l *executeLogger) GetLogs() []engine.ExecuteEntry { +func (l *ExecuteLogger) GetLogs() []engine.ExecuteEntry { + if l == nil { + return nil + } l.mu.Lock() defer l.mu.Unlock() result := make([]engine.ExecuteEntry, len(l.entries)) diff --git a/go/vt/vtgate/safe_session_test.go b/go/vt/vtgate/executorcontext/safe_session_test.go similarity index 87% rename from go/vt/vtgate/safe_session_test.go rename to go/vt/vtgate/executorcontext/safe_session_test.go index 9383058e81c..14ea2ad9dac 100644 --- a/go/vt/vtgate/safe_session_test.go +++ b/go/vt/vtgate/executorcontext/safe_session_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package vtgate +package executorcontext import ( "reflect" @@ -29,6 +29,31 @@ import ( vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" ) +type fakeInfo struct { + transactionID int64 + alias *topodatapb.TabletAlias +} + +func (s *fakeInfo) TransactionID() int64 { + return s.transactionID +} + +func (s *fakeInfo) ReservedID() int64 { + return 0 +} + +func (s *fakeInfo) RowsAffected() bool { + return false +} + +func (s *fakeInfo) Alias() *topodatapb.TabletAlias { + return s.alias +} + +func info(txId, uid int) ShardActionInfo { + return &fakeInfo{transactionID: int64(txId), alias: &topodatapb.TabletAlias{Cell: "cell", Uid: uint32(uid)}} +} + // TestFailToMultiShardWhenSetToSingleDb tests that single db transactions fails on going multi shard. func TestFailToMultiShardWhenSetToSingleDb(t *testing.T) { session := NewSafeSession(&vtgatepb.Session{ @@ -37,13 +62,13 @@ func TestFailToMultiShardWhenSetToSingleDb(t *testing.T) { err := session.AppendOrUpdate( &querypb.Target{Keyspace: "keyspace", Shard: "0"}, - &shardActionInfo{transactionID: 1, alias: &topodatapb.TabletAlias{Cell: "cell", Uid: 0}}, + info(1, 0), nil, vtgatepb.TransactionMode_SINGLE) require.NoError(t, err) err = session.AppendOrUpdate( &querypb.Target{Keyspace: "keyspace", Shard: "1"}, - &shardActionInfo{transactionID: 1, alias: &topodatapb.TabletAlias{Cell: "cell", Uid: 1}}, + info(1, 1), nil, vtgatepb.TransactionMode_SINGLE) require.Error(t, err) @@ -59,7 +84,7 @@ func TestSingleDbUpdateToMultiShard(t *testing.T) { session.queryFromVindex = true err := session.AppendOrUpdate( &querypb.Target{Keyspace: "keyspace", Shard: "0"}, - &shardActionInfo{transactionID: 1, alias: &topodatapb.TabletAlias{Cell: "cell", Uid: 0}}, + info(1, 0), nil, vtgatepb.TransactionMode_SINGLE) require.NoError(t, err) @@ -68,7 +93,7 @@ func TestSingleDbUpdateToMultiShard(t *testing.T) { // shard session s1 err = session.AppendOrUpdate( &querypb.Target{Keyspace: "keyspace", Shard: "1"}, - &shardActionInfo{transactionID: 1, alias: &topodatapb.TabletAlias{Cell: "cell", Uid: 1}}, + info(1, 1), nil, vtgatepb.TransactionMode_SINGLE) require.NoError(t, err) @@ -76,7 +101,7 @@ func TestSingleDbUpdateToMultiShard(t *testing.T) { // shard session s0 with normal query err = session.AppendOrUpdate( &querypb.Target{Keyspace: "keyspace", Shard: "0"}, - &shardActionInfo{transactionID: 1, alias: &topodatapb.TabletAlias{Cell: "cell", Uid: 1}}, + info(1, 1), session.ShardSessions[0], vtgatepb.TransactionMode_SINGLE) require.Error(t, err) @@ -93,7 +118,7 @@ func TestSingleDbPreFailOnFind(t *testing.T) { session.queryFromVindex = true err := session.AppendOrUpdate( &querypb.Target{Keyspace: "keyspace", Shard: "0"}, - &shardActionInfo{transactionID: 1, alias: &topodatapb.TabletAlias{Cell: "cell", Uid: 0}}, + info(1, 0), nil, vtgatepb.TransactionMode_SINGLE) require.NoError(t, err) @@ -102,7 +127,7 @@ func TestSingleDbPreFailOnFind(t *testing.T) { // shard session s1 err = session.AppendOrUpdate( &querypb.Target{Keyspace: "keyspace", Shard: "1"}, - &shardActionInfo{transactionID: 1, alias: &topodatapb.TabletAlias{Cell: "cell", Uid: 1}}, + info(1, 1), nil, vtgatepb.TransactionMode_SINGLE) require.NoError(t, err) diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/executorcontext/vcursor_impl.go similarity index 66% rename from go/vt/vtgate/vcursor_impl.go rename to go/vt/vtgate/executorcontext/vcursor_impl.go index 691b9988d9e..d6aac5cf5b0 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package vtgate +package executorcontext import ( "context" @@ -26,6 +26,7 @@ import ( "time" "github.com/google/uuid" + "golang.org/x/exp/maps" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/config" @@ -59,38 +60,62 @@ import ( ) var ( - _ engine.VCursor = (*vcursorImpl)(nil) - _ plancontext.VSchema = (*vcursorImpl)(nil) - _ iExecute = (*Executor)(nil) - _ vindexes.VCursor = (*vcursorImpl)(nil) + _ engine.VCursor = (*VCursorImpl)(nil) + _ plancontext.VSchema = (*VCursorImpl)(nil) + _ vindexes.VCursor = (*VCursorImpl)(nil) ) +var ErrNoKeyspace = vterrors.VT09005() + type ( + ResultsObserver interface { + Observe(*sqltypes.Result) + } + + VCursorConfig struct { + Collation collations.ID + + MaxMemoryRows int + EnableShardRouting bool + DefaultTabletType topodatapb.TabletType + QueryTimeout int + DBDDLPlugin string + ForeignKeyMode vschemapb.Keyspace_ForeignKeyMode + SetVarEnabled bool + EnableViews bool + WarnShardedOnly bool + PlannerVersion plancontext.PlannerVersion + + WarmingReadsPercent int + WarmingReadsTimeout time.Duration + WarmingReadsChannel chan bool + } + // vcursor_impl needs these facilities to be able to be able to execute queries for vindexes iExecute interface { Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, session *SafeSession, s string, vars map[string]*querypb.BindVariable) (*sqltypes.Result, error) - ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver resultsObserver) (qr *sqltypes.Result, errs []error) - StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, observer resultsObserver) []error + ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver ResultsObserver) (qr *sqltypes.Result, errs []error) + StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, observer ResultsObserver) []error ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) Commit(ctx context.Context, safeSession *SafeSession) error ExecuteMessageStream(ctx context.Context, rss []*srvtopo.ResolvedShard, name string, callback func(*sqltypes.Result) error) error ExecuteVStream(ctx context.Context, rss []*srvtopo.ResolvedShard, filter *binlogdatapb.Filter, gtid string, callback func(evs []*binlogdatapb.VEvent) error) error ReleaseLock(ctx context.Context, session *SafeSession) error - showVitessReplicationStatus(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) - showShards(ctx context.Context, filter *sqlparser.ShowFilter, destTabletType topodatapb.TabletType) (*sqltypes.Result, error) - showTablets(filter *sqlparser.ShowFilter) (*sqltypes.Result, error) - showVitessMetadata(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) - setVitessMetadata(ctx context.Context, name, value string) error + ShowVitessReplicationStatus(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) + ShowShards(ctx context.Context, filter *sqlparser.ShowFilter, destTabletType topodatapb.TabletType) (*sqltypes.Result, error) + ShowTablets(filter *sqlparser.ShowFilter) (*sqltypes.Result, error) + ShowVitessMetadata(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) + SetVitessMetadata(ctx context.Context, name, value string) error // TODO: remove when resolver is gone - ParseDestinationTarget(targetString string) (string, topodatapb.TabletType, key.Destination, error) VSchema() *vindexes.VSchema - planPrepareStmt(ctx context.Context, vcursor *vcursorImpl, query string) (*engine.Plan, sqlparser.Statement, error) + PlanPrepareStmt(ctx context.Context, vcursor *VCursorImpl, query string) (*engine.Plan, sqlparser.Statement, error) - environment() *vtenv.Environment + Environment() *vtenv.Environment ReadTransaction(ctx context.Context, transactionID string) (*querypb.TransactionMetadata, error) UnresolvedTransactions(ctx context.Context, targets []*querypb.Target) ([]*querypb.TransactionMetadata, error) + AddWarningCount(name string, value int64) } // VSchemaOperator is an interface to Vschema Operations @@ -99,10 +124,11 @@ type ( UpdateVSchema(ctx context.Context, ksName string, vschema *vschemapb.SrvVSchema) error } - // vcursorImpl implements the VCursor functionality used by dependent + // VCursorImpl implements the VCursor functionality used by dependent // packages to call back into VTGate. - vcursorImpl struct { - safeSession *SafeSession + VCursorImpl struct { + config VCursorConfig + SafeSession *SafeSession keyspace string tabletType topodatapb.TabletType destination key.Destination @@ -111,7 +137,6 @@ type ( resolver *srvtopo.Resolver topoServer *topo.Server logStats *logstats.LogStats - collation collations.ID // fkChecksState stores the state of foreign key checks variable. // This state is meant to be the final fk checks state after consulting the @@ -122,16 +147,11 @@ type ( vschema *vindexes.VSchema vm VSchemaOperator semTable *semantics.SemTable - warnShardedOnly bool // when using sharded only features, a warning will be warnings field queryTimeout time.Duration warnings []*querypb.QueryWarning // any warnings that are accumulated during the planning phase are stored here - pv plancontext.PlannerVersion - warmingReadsPercent int - warmingReadsChannel chan bool - - resultsObserver resultsObserver + observer ResultsObserver // this is a map of the number of rows that every primitive has returned // if this field is nil, it means that we are not logging operator traffic @@ -140,23 +160,23 @@ type ( } ) -// newVcursorImpl creates a vcursorImpl. Before creating this object, you have to separate out any marginComments that came with +// NewVCursorImpl creates a VCursorImpl. Before creating this object, you have to separate out any marginComments that came with // the query and supply it here. Trailing comments are typically sent by the application for various reasons, // including as identifying markers. So, they have to be added back to all queries that are executed // on behalf of the original query. -func newVCursorImpl( +func NewVCursorImpl( safeSession *SafeSession, marginComments sqlparser.MarginComments, - executor *Executor, + executor iExecute, logStats *logstats.LogStats, vm VSchemaOperator, vschema *vindexes.VSchema, resolver *srvtopo.Resolver, serv srvtopo.Server, - warnShardedOnly bool, - pv plancontext.PlannerVersion, -) (*vcursorImpl, error) { - keyspace, tabletType, destination, err := parseDestinationTarget(safeSession.TargetString, vschema) + observer ResultsObserver, + cfg VCursorConfig, +) (*VCursorImpl, error) { + keyspace, tabletType, destination, err := ParseDestinationTarget(safeSession.TargetString, cfg.DefaultTabletType, vschema) if err != nil { return nil, err } @@ -171,107 +191,175 @@ func newVCursorImpl( } } - // we only support collations for the new TabletGateway implementation - var connCollation collations.ID - if executor != nil { - if gw, isTabletGw := executor.resolver.resolver.GetGateway().(*TabletGateway); isTabletGw { - connCollation = gw.DefaultConnCollation() - } - } - if connCollation == collations.Unknown { - connCollation = executor.env.CollationEnv().DefaultConnectionCharset() - } - - warmingReadsPct := 0 - var warmingReadsChan chan bool - if executor != nil { - warmingReadsPct = executor.warmingReadsPercent - warmingReadsChan = executor.warmingReadsChannel - } - return &vcursorImpl{ - safeSession: safeSession, - keyspace: keyspace, - tabletType: tabletType, - destination: destination, - marginComments: marginComments, - executor: executor, - logStats: logStats, - collation: connCollation, - resolver: resolver, - vschema: vschema, - vm: vm, - topoServer: ts, - warnShardedOnly: warnShardedOnly, - pv: pv, - warmingReadsPercent: warmingReadsPct, - warmingReadsChannel: warmingReadsChan, - resultsObserver: nullResultsObserver{}, + return &VCursorImpl{ + config: cfg, + SafeSession: safeSession, + keyspace: keyspace, + tabletType: tabletType, + destination: destination, + marginComments: marginComments, + executor: executor, + logStats: logStats, + resolver: resolver, + vschema: vschema, + vm: vm, + topoServer: ts, + + observer: observer, }, nil } +func (vc *VCursorImpl) CloneForMirroring(ctx context.Context) engine.VCursor { + callerId := callerid.EffectiveCallerIDFromContext(ctx) + immediateCallerId := callerid.ImmediateCallerIDFromContext(ctx) + + clonedCtx := callerid.NewContext(ctx, callerId, immediateCallerId) + + v := &VCursorImpl{ + config: vc.config, + SafeSession: NewAutocommitSession(vc.SafeSession.Session), + keyspace: vc.keyspace, + tabletType: vc.tabletType, + destination: vc.destination, + marginComments: vc.marginComments, + executor: vc.executor, + resolver: vc.resolver, + topoServer: vc.topoServer, + logStats: &logstats.LogStats{Ctx: clonedCtx}, + ignoreMaxMemoryRows: vc.ignoreMaxMemoryRows, + vschema: vc.vschema, + vm: vc.vm, + semTable: vc.semTable, + warnings: vc.warnings, + observer: vc.observer, + } + + v.marginComments.Trailing += "/* mirror query */" + + return v +} + +func (vc *VCursorImpl) CloneForReplicaWarming(ctx context.Context) engine.VCursor { + callerId := callerid.EffectiveCallerIDFromContext(ctx) + immediateCallerId := callerid.ImmediateCallerIDFromContext(ctx) + + timedCtx, _ := context.WithTimeout(context.Background(), vc.config.WarmingReadsTimeout) // nolint + clonedCtx := callerid.NewContext(timedCtx, callerId, immediateCallerId) + + v := &VCursorImpl{ + config: vc.config, + SafeSession: NewAutocommitSession(vc.SafeSession.Session), + keyspace: vc.keyspace, + tabletType: topodatapb.TabletType_REPLICA, + destination: vc.destination, + marginComments: vc.marginComments, + executor: vc.executor, + resolver: vc.resolver, + topoServer: vc.topoServer, + logStats: &logstats.LogStats{Ctx: clonedCtx}, + + ignoreMaxMemoryRows: vc.ignoreMaxMemoryRows, + vschema: vc.vschema, + vm: vc.vm, + semTable: vc.semTable, + warnings: vc.warnings, + observer: vc.observer, + } + + v.marginComments.Trailing += "/* warming read */" + + return v +} + +func (vc *VCursorImpl) cloneWithAutocommitSession() *VCursorImpl { + safeSession := vc.SafeSession.NewAutocommitSession() + return &VCursorImpl{ + config: vc.config, + SafeSession: safeSession, + keyspace: vc.keyspace, + tabletType: vc.tabletType, + destination: vc.destination, + marginComments: vc.marginComments, + executor: vc.executor, + logStats: vc.logStats, + resolver: vc.resolver, + vschema: vc.vschema, + vm: vc.vm, + topoServer: vc.topoServer, + observer: vc.observer, + } +} + // HasSystemVariables returns whether the session has set system variables or not -func (vc *vcursorImpl) HasSystemVariables() bool { - return vc.safeSession.HasSystemVariables() +func (vc *VCursorImpl) HasSystemVariables() bool { + return vc.SafeSession.HasSystemVariables() } // GetSystemVariables takes a visitor function that will save each system variables of the session -func (vc *vcursorImpl) GetSystemVariables(f func(k string, v string)) { - vc.safeSession.GetSystemVariables(f) +func (vc *VCursorImpl) GetSystemVariables(f func(k string, v string)) { + vc.SafeSession.GetSystemVariables(f) +} + +// GetSystemVariablesCopy returns a copy of the system variables of the session. Changes to the original map will not affect the session. +func (vc *VCursorImpl) GetSystemVariablesCopy() map[string]string { + vc.SafeSession.mu.Lock() + defer vc.SafeSession.mu.Unlock() + return maps.Clone(vc.SafeSession.SystemVariables) } // ConnCollation returns the collation of this session -func (vc *vcursorImpl) ConnCollation() collations.ID { - return vc.collation +func (vc *VCursorImpl) ConnCollation() collations.ID { + return vc.config.Collation } // Environment returns the vtenv associated with this session -func (vc *vcursorImpl) Environment() *vtenv.Environment { - return vc.executor.environment() +func (vc *VCursorImpl) Environment() *vtenv.Environment { + return vc.executor.Environment() } -func (vc *vcursorImpl) TimeZone() *time.Location { - return vc.safeSession.TimeZone() +func (vc *VCursorImpl) TimeZone() *time.Location { + return vc.SafeSession.TimeZone() } -func (vc *vcursorImpl) SQLMode() string { +func (vc *VCursorImpl) SQLMode() string { // TODO: Implement return the current sql_mode. // This is currently hardcoded to the default in MySQL 8.0. return config.DefaultSQLMode } // MaxMemoryRows returns the maxMemoryRows flag value. -func (vc *vcursorImpl) MaxMemoryRows() int { - return maxMemoryRows +func (vc *VCursorImpl) MaxMemoryRows() int { + return vc.config.MaxMemoryRows } // ExceedsMaxMemoryRows returns a boolean indicating whether the maxMemoryRows value has been exceeded. // Returns false if the max memory rows override directive is set to true. -func (vc *vcursorImpl) ExceedsMaxMemoryRows(numRows int) bool { - return !vc.ignoreMaxMemoryRows && numRows > maxMemoryRows +func (vc *VCursorImpl) ExceedsMaxMemoryRows(numRows int) bool { + return !vc.ignoreMaxMemoryRows && numRows > vc.config.MaxMemoryRows } // SetIgnoreMaxMemoryRows sets the ignoreMaxMemoryRows value. -func (vc *vcursorImpl) SetIgnoreMaxMemoryRows(ignoreMaxMemoryRows bool) { +func (vc *VCursorImpl) SetIgnoreMaxMemoryRows(ignoreMaxMemoryRows bool) { vc.ignoreMaxMemoryRows = ignoreMaxMemoryRows } // RecordWarning stores the given warning in the current session -func (vc *vcursorImpl) RecordWarning(warning *querypb.QueryWarning) { - vc.safeSession.RecordWarning(warning) +func (vc *VCursorImpl) RecordWarning(warning *querypb.QueryWarning) { + vc.SafeSession.RecordWarning(warning) } // IsShardRoutingEnabled implements the VCursor interface. -func (vc *vcursorImpl) IsShardRoutingEnabled() bool { - return enableShardRouting +func (vc *VCursorImpl) IsShardRoutingEnabled() bool { + return vc.config.EnableShardRouting } -func (vc *vcursorImpl) ReadTransaction(ctx context.Context, transactionID string) (*querypb.TransactionMetadata, error) { +func (vc *VCursorImpl) ReadTransaction(ctx context.Context, transactionID string) (*querypb.TransactionMetadata, error) { return vc.executor.ReadTransaction(ctx, transactionID) } // UnresolvedTransactions gets the unresolved transactions for the given keyspace. If the keyspace is not given, // then we use the default keyspace. -func (vc *vcursorImpl) UnresolvedTransactions(ctx context.Context, keyspace string) ([]*querypb.TransactionMetadata, error) { +func (vc *VCursorImpl) UnresolvedTransactions(ctx context.Context, keyspace string) ([]*querypb.TransactionMetadata, error) { if keyspace == "" { keyspace = vc.GetKeyspace() } @@ -286,7 +374,7 @@ func (vc *vcursorImpl) UnresolvedTransactions(ctx context.Context, keyspace stri return vc.executor.UnresolvedTransactions(ctx, targets) } -func (vc *vcursorImpl) StartPrimitiveTrace() func() engine.Stats { +func (vc *VCursorImpl) StartPrimitiveTrace() func() engine.Stats { vc.interOpStats = make(map[engine.Primitive]engine.RowsReceived) vc.shardsStats = make(map[engine.Primitive]engine.ShardsQueried) return func() engine.Stats { @@ -299,8 +387,8 @@ func (vc *vcursorImpl) StartPrimitiveTrace() func() engine.Stats { // FindTable finds the specified table. If the keyspace what specified in the input, it gets used as qualifier. // Otherwise, the keyspace from the request is used, if one was provided. -func (vc *vcursorImpl) FindTable(name sqlparser.TableName) (*vindexes.Table, string, topodatapb.TabletType, key.Destination, error) { - destKeyspace, destTabletType, dest, err := vc.executor.ParseDestinationTarget(name.Qualifier.String()) +func (vc *VCursorImpl) FindTable(name sqlparser.TableName) (*vindexes.Table, string, topodatapb.TabletType, key.Destination, error) { + destKeyspace, destTabletType, dest, err := vc.ParseDestinationTarget(name.Qualifier.String()) if err != nil { return nil, "", destTabletType, nil, err } @@ -314,8 +402,8 @@ func (vc *vcursorImpl) FindTable(name sqlparser.TableName) (*vindexes.Table, str return table, destKeyspace, destTabletType, dest, err } -func (vc *vcursorImpl) FindView(name sqlparser.TableName) sqlparser.SelectStatement { - ks, _, _, err := vc.executor.ParseDestinationTarget(name.Qualifier.String()) +func (vc *VCursorImpl) FindView(name sqlparser.TableName) sqlparser.SelectStatement { + ks, _, _, err := vc.ParseDestinationTarget(name.Qualifier.String()) if err != nil { return nil } @@ -325,8 +413,8 @@ func (vc *vcursorImpl) FindView(name sqlparser.TableName) sqlparser.SelectStatem return vc.vschema.FindView(ks, name.Name.String()) } -func (vc *vcursorImpl) FindRoutedTable(name sqlparser.TableName) (*vindexes.Table, error) { - destKeyspace, destTabletType, _, err := vc.executor.ParseDestinationTarget(name.Qualifier.String()) +func (vc *VCursorImpl) FindRoutedTable(name sqlparser.TableName) (*vindexes.Table, error) { + destKeyspace, destTabletType, _, err := vc.ParseDestinationTarget(name.Qualifier.String()) if err != nil { return nil, err } @@ -343,14 +431,14 @@ func (vc *vcursorImpl) FindRoutedTable(name sqlparser.TableName) (*vindexes.Tabl } // FindTableOrVindex finds the specified table or vindex. -func (vc *vcursorImpl) FindTableOrVindex(name sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) { +func (vc *VCursorImpl) FindTableOrVindex(name sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) { if name.Qualifier.IsEmpty() && name.Name.String() == "dual" { // The magical MySQL dual table should only be resolved // when it is not qualified by a database name. return vc.getDualTable() } - destKeyspace, destTabletType, dest, err := vc.executor.ParseDestinationTarget(name.Qualifier.String()) + destKeyspace, destTabletType, dest, err := ParseDestinationTarget(name.Qualifier.String(), vc.tabletType, vc.vschema) if err != nil { return nil, nil, "", destTabletType, nil, err } @@ -364,7 +452,23 @@ func (vc *vcursorImpl) FindTableOrVindex(name sqlparser.TableName) (*vindexes.Ta return table, vindex, destKeyspace, destTabletType, dest, nil } -func (vc *vcursorImpl) getDualTable() (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) { +func (vc *VCursorImpl) ParseDestinationTarget(targetString string) (string, topodatapb.TabletType, key.Destination, error) { + return ParseDestinationTarget(targetString, vc.tabletType, vc.vschema) +} + +// ParseDestinationTarget parses destination target string and provides a keyspace if possible. +func ParseDestinationTarget(targetString string, tablet topodatapb.TabletType, vschema *vindexes.VSchema) (string, topodatapb.TabletType, key.Destination, error) { + destKeyspace, destTabletType, dest, err := topoprotopb.ParseDestination(targetString, tablet) + // If the keyspace is not specified, and there is only one keyspace in the VSchema, use that. + if destKeyspace == "" && len(vschema.Keyspaces) == 1 { + for k := range vschema.Keyspaces { + destKeyspace = k + } + } + return destKeyspace, destTabletType, dest, err +} + +func (vc *VCursorImpl) getDualTable() (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) { ksName := vc.getActualKeyspace() var ks *vindexes.Keyspace if ksName == "" { @@ -381,7 +485,7 @@ func (vc *vcursorImpl) getDualTable() (*vindexes.Table, vindexes.Vindex, string, return tbl, nil, ksName, topodatapb.TabletType_PRIMARY, nil, nil } -func (vc *vcursorImpl) getActualKeyspace() string { +func (vc *VCursorImpl) getActualKeyspace() string { if !sqlparser.SystemSchema(vc.keyspace) { return vc.keyspace } @@ -395,9 +499,9 @@ func (vc *vcursorImpl) getActualKeyspace() string { // SelectedKeyspace returns the selected keyspace of the current request // if there is one. If the keyspace specified in the target cannot be // identified, it returns an error. -func (vc *vcursorImpl) SelectedKeyspace() (*vindexes.Keyspace, error) { +func (vc *VCursorImpl) SelectedKeyspace() (*vindexes.Keyspace, error) { if ignoreKeyspace(vc.keyspace) { - return nil, errNoKeyspace + return nil, ErrNoKeyspace } ks, ok := vc.vschema.Keyspaces[vc.keyspace] if !ok { @@ -408,12 +512,12 @@ func (vc *vcursorImpl) SelectedKeyspace() (*vindexes.Keyspace, error) { var errNoDbAvailable = vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.NoDB, "no database available") -func (vc *vcursorImpl) AnyKeyspace() (*vindexes.Keyspace, error) { +func (vc *VCursorImpl) AnyKeyspace() (*vindexes.Keyspace, error) { keyspace, err := vc.SelectedKeyspace() if err == nil { return keyspace, nil } - if err != errNoKeyspace { + if err != ErrNoKeyspace { return nil, err } @@ -434,7 +538,7 @@ func (vc *vcursorImpl) AnyKeyspace() (*vindexes.Keyspace, error) { } // getSortedServingKeyspaces gets the sorted serving keyspaces -func (vc *vcursorImpl) getSortedServingKeyspaces() []*vindexes.Keyspace { +func (vc *VCursorImpl) getSortedServingKeyspaces() []*vindexes.Keyspace { var keyspaces []*vindexes.Keyspace if vc.resolver != nil && vc.resolver.GetGateway() != nil { @@ -458,7 +562,7 @@ func (vc *vcursorImpl) getSortedServingKeyspaces() []*vindexes.Keyspace { return keyspaces } -func (vc *vcursorImpl) FirstSortedKeyspace() (*vindexes.Keyspace, error) { +func (vc *VCursorImpl) FirstSortedKeyspace() (*vindexes.Keyspace, error) { if len(vc.vschema.Keyspaces) == 0 { return nil, errNoDbAvailable } @@ -468,17 +572,17 @@ func (vc *vcursorImpl) FirstSortedKeyspace() (*vindexes.Keyspace, error) { } // SysVarSetEnabled implements the ContextVSchema interface -func (vc *vcursorImpl) SysVarSetEnabled() bool { +func (vc *VCursorImpl) SysVarSetEnabled() bool { return vc.GetSessionEnableSystemSettings() } // KeyspaceExists provides whether the keyspace exists or not. -func (vc *vcursorImpl) KeyspaceExists(ks string) bool { +func (vc *VCursorImpl) KeyspaceExists(ks string) bool { return vc.vschema.Keyspaces[ks] != nil } // AllKeyspace implements the ContextVSchema interface -func (vc *vcursorImpl) AllKeyspace() ([]*vindexes.Keyspace, error) { +func (vc *VCursorImpl) AllKeyspace() ([]*vindexes.Keyspace, error) { if len(vc.vschema.Keyspaces) == 0 { return nil, errNoDbAvailable } @@ -490,7 +594,7 @@ func (vc *vcursorImpl) AllKeyspace() ([]*vindexes.Keyspace, error) { } // FindKeyspace implements the VSchema interface -func (vc *vcursorImpl) FindKeyspace(keyspace string) (*vindexes.Keyspace, error) { +func (vc *VCursorImpl) FindKeyspace(keyspace string) (*vindexes.Keyspace, error) { if len(vc.vschema.Keyspaces) == 0 { return nil, errNoDbAvailable } @@ -503,28 +607,28 @@ func (vc *vcursorImpl) FindKeyspace(keyspace string) (*vindexes.Keyspace, error) } // Planner implements the ContextVSchema interface -func (vc *vcursorImpl) Planner() plancontext.PlannerVersion { - if vc.safeSession.Options != nil && - vc.safeSession.Options.PlannerVersion != querypb.ExecuteOptions_DEFAULT_PLANNER { - return vc.safeSession.Options.PlannerVersion +func (vc *VCursorImpl) Planner() plancontext.PlannerVersion { + if vc.SafeSession.Options != nil && + vc.SafeSession.Options.PlannerVersion != querypb.ExecuteOptions_DEFAULT_PLANNER { + return vc.SafeSession.Options.PlannerVersion } - return vc.pv + return vc.config.PlannerVersion } // GetSemTable implements the ContextVSchema interface -func (vc *vcursorImpl) GetSemTable() *semantics.SemTable { +func (vc *VCursorImpl) GetSemTable() *semantics.SemTable { return vc.semTable } // TargetString returns the current TargetString of the session. -func (vc *vcursorImpl) TargetString() string { - return vc.safeSession.TargetString +func (vc *VCursorImpl) TargetString() string { + return vc.SafeSession.TargetString } // MaxBufferingRetries is to represent max retries on buffering. const MaxBufferingRetries = 3 -func (vc *vcursorImpl) ExecutePrimitive(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { +func (vc *VCursorImpl) ExecutePrimitive(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { for try := 0; try < MaxBufferingRetries; try++ { res, err := primitive.TryExecute(ctx, vc, bindVars, wantfields) if err != nil && vterrors.RootCause(err) == buffer.ShardMissingError { @@ -536,7 +640,7 @@ func (vc *vcursorImpl) ExecutePrimitive(ctx context.Context, primitive engine.Pr return nil, vterrors.New(vtrpcpb.Code_UNAVAILABLE, "upstream shards are not available") } -func (vc *vcursorImpl) logOpTraffic(primitive engine.Primitive, res *sqltypes.Result) { +func (vc *VCursorImpl) logOpTraffic(primitive engine.Primitive, res *sqltypes.Result) { if vc.interOpStats != nil { rows := vc.interOpStats[primitive] if res == nil { @@ -548,14 +652,14 @@ func (vc *vcursorImpl) logOpTraffic(primitive engine.Primitive, res *sqltypes.Re } } -func (vc *vcursorImpl) logShardsQueried(primitive engine.Primitive, shardsNb int) { +func (vc *VCursorImpl) logShardsQueried(primitive engine.Primitive, shardsNb int) { if vc.shardsStats != nil { vc.shardsStats[primitive] += engine.ShardsQueried(shardsNb) } } -func (vc *vcursorImpl) ExecutePrimitiveStandalone(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - // clone the vcursorImpl with a new session. +func (vc *VCursorImpl) ExecutePrimitiveStandalone(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + // clone the VCursorImpl with a new session. newVC := vc.cloneWithAutocommitSession() for try := 0; try < MaxBufferingRetries; try++ { res, err := primitive.TryExecute(ctx, newVC, bindVars, wantfields) @@ -568,7 +672,7 @@ func (vc *vcursorImpl) ExecutePrimitiveStandalone(ctx context.Context, primitive return nil, vterrors.New(vtrpcpb.Code_UNAVAILABLE, "upstream shards are not available") } -func (vc *vcursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primitive engine.Primitive) func(*sqltypes.Result) error { +func (vc *VCursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primitive engine.Primitive) func(*sqltypes.Result) error { if vc.interOpStats == nil { return callback } @@ -579,7 +683,7 @@ func (vc *vcursorImpl) wrapCallback(callback func(*sqltypes.Result) error, primi } } -func (vc *vcursorImpl) StreamExecutePrimitive(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { +func (vc *VCursorImpl) StreamExecutePrimitive(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { callback = vc.wrapCallback(callback, primitive) for try := 0; try < MaxBufferingRetries; try++ { @@ -592,10 +696,10 @@ func (vc *vcursorImpl) StreamExecutePrimitive(ctx context.Context, primitive eng return vterrors.New(vtrpcpb.Code_UNAVAILABLE, "upstream shards are not available") } -func (vc *vcursorImpl) StreamExecutePrimitiveStandalone(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(result *sqltypes.Result) error) error { +func (vc *VCursorImpl) StreamExecutePrimitiveStandalone(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(result *sqltypes.Result) error) error { callback = vc.wrapCallback(callback, primitive) - // clone the vcursorImpl with a new session. + // clone the VCursorImpl with a new session. newVC := vc.cloneWithAutocommitSession() for try := 0; try < MaxBufferingRetries; try++ { err := primitive.TryStreamExecute(ctx, newVC, bindVars, wantfields, callback) @@ -608,12 +712,11 @@ func (vc *vcursorImpl) StreamExecutePrimitiveStandalone(ctx context.Context, pri } // Execute is part of the engine.VCursor interface. -func (vc *vcursorImpl) Execute(ctx context.Context, method string, query string, bindVars map[string]*querypb.BindVariable, rollbackOnError bool, co vtgatepb.CommitOrder) (*sqltypes.Result, error) { - session := vc.safeSession +func (vc *VCursorImpl) Execute(ctx context.Context, method string, query string, bindVars map[string]*querypb.BindVariable, rollbackOnError bool, co vtgatepb.CommitOrder) (*sqltypes.Result, error) { + session := vc.SafeSession if co == vtgatepb.CommitOrder_AUTOCOMMIT { // For autocommit, we have to create an independent session. - session = NewAutocommitSession(vc.safeSession.Session) - session.logging = vc.safeSession.logging + session = vc.SafeSession.NewAutocommitSession() rollbackOnError = false } else { session.SetCommitOrder(co) @@ -634,24 +737,22 @@ func (vc *vcursorImpl) Execute(ctx context.Context, method string, query string, // markSavepoint opens an internal savepoint before executing the original query. // This happens only when rollback is allowed and no other savepoint was executed // and the query is executed in an explicit transaction (i.e. started by the client). -func (vc *vcursorImpl) markSavepoint(ctx context.Context, needsRollbackOnParialExec bool, bindVars map[string]*querypb.BindVariable) error { - if !needsRollbackOnParialExec || !vc.safeSession.CanAddSavepoint() { +func (vc *VCursorImpl) markSavepoint(ctx context.Context, needsRollbackOnParialExec bool, bindVars map[string]*querypb.BindVariable) error { + if !needsRollbackOnParialExec || !vc.SafeSession.CanAddSavepoint() { return nil } uID := fmt.Sprintf("_vt%s", strings.ReplaceAll(uuid.NewString(), "-", "_")) spQuery := fmt.Sprintf("%ssavepoint %s%s", vc.marginComments.Leading, uID, vc.marginComments.Trailing) - _, err := vc.executor.Execute(ctx, nil, "MarkSavepoint", vc.safeSession, spQuery, bindVars) + _, err := vc.executor.Execute(ctx, nil, "MarkSavepoint", vc.SafeSession, spQuery, bindVars) if err != nil { return err } - vc.safeSession.SetSavepoint(uID) + vc.SafeSession.SetSavepoint(uID) return nil } -const txRollback = "Rollback Transaction" - // ExecuteMultiShard is part of the engine.VCursor interface. -func (vc *vcursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, rollbackOnError, canAutocommit bool) (*sqltypes.Result, []error) { +func (vc *VCursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, rollbackOnError, canAutocommit bool) (*sqltypes.Result, []error) { noOfShards := len(rss) atomic.AddUint64(&vc.logStats.ShardQueries, uint64(noOfShards)) err := vc.markSavepoint(ctx, rollbackOnError && (noOfShards > 1), map[string]*querypb.BindVariable{}) @@ -659,14 +760,14 @@ func (vc *vcursorImpl) ExecuteMultiShard(ctx context.Context, primitive engine.P return nil, []error{err} } - qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.safeSession, canAutocommit, vc.ignoreMaxMemoryRows, vc.resultsObserver) + qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, commentedShardQueries(queries, vc.marginComments), vc.SafeSession, canAutocommit, vc.ignoreMaxMemoryRows, vc.observer) vc.setRollbackOnPartialExecIfRequired(len(errs) != len(rss), rollbackOnError) vc.logShardsQueried(primitive, len(rss)) return qr, errs } // StreamExecuteMulti is the streaming version of ExecuteMultiShard. -func (vc *vcursorImpl) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, bindVars []map[string]*querypb.BindVariable, rollbackOnError bool, autocommit bool, callback func(reply *sqltypes.Result) error) []error { +func (vc *VCursorImpl) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, bindVars []map[string]*querypb.BindVariable, rollbackOnError bool, autocommit bool, callback func(reply *sqltypes.Result) error) []error { callback = vc.wrapCallback(callback, primitive) noOfShards := len(rss) @@ -676,20 +777,20 @@ func (vc *vcursorImpl) StreamExecuteMulti(ctx context.Context, primitive engine. return []error{err} } - errs := vc.executor.StreamExecuteMulti(ctx, primitive, vc.marginComments.Leading+query+vc.marginComments.Trailing, rss, bindVars, vc.safeSession, autocommit, callback, vc.resultsObserver) + errs := vc.executor.StreamExecuteMulti(ctx, primitive, vc.marginComments.Leading+query+vc.marginComments.Trailing, rss, bindVars, vc.SafeSession, autocommit, callback, vc.observer) vc.setRollbackOnPartialExecIfRequired(len(errs) != len(rss), rollbackOnError) return errs } // ExecuteLock is for executing advisory lock statements. -func (vc *vcursorImpl) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) { +func (vc *VCursorImpl) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) { query.Sql = vc.marginComments.Leading + query.Sql + vc.marginComments.Trailing - return vc.executor.ExecuteLock(ctx, rs, query, vc.safeSession, lockFuncType) + return vc.executor.ExecuteLock(ctx, rs, query, vc.SafeSession, lockFuncType) } // ExecuteStandalone is part of the engine.VCursor interface. -func (vc *vcursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.Primitive, query string, bindVars map[string]*querypb.BindVariable, rs *srvtopo.ResolvedShard) (*sqltypes.Result, error) { +func (vc *VCursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.Primitive, query string, bindVars map[string]*querypb.BindVariable, rs *srvtopo.ResolvedShard) (*sqltypes.Result, error) { rss := []*srvtopo.ResolvedShard{rs} bqs := []*querypb.BoundQuery{ { @@ -699,13 +800,13 @@ func (vc *vcursorImpl) ExecuteStandalone(ctx context.Context, primitive engine.P } // The autocommit flag is always set to false because we currently don't // execute DMLs through ExecuteStandalone. - qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.safeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, vc.resultsObserver) + qr, errs := vc.executor.ExecuteMultiShard(ctx, primitive, rss, bqs, NewAutocommitSession(vc.SafeSession.Session), false /* autocommit */, vc.ignoreMaxMemoryRows, vc.observer) vc.logShardsQueried(primitive, len(rss)) return qr, vterrors.Aggregate(errs) } // ExecuteKeyspaceID is part of the engine.VCursor interface. -func (vc *vcursorImpl) ExecuteKeyspaceID(ctx context.Context, keyspace string, ksid []byte, query string, bindVars map[string]*querypb.BindVariable, rollbackOnError, autocommit bool) (*sqltypes.Result, error) { +func (vc *VCursorImpl) ExecuteKeyspaceID(ctx context.Context, keyspace string, ksid []byte, query string, bindVars map[string]*querypb.BindVariable, rollbackOnError, autocommit bool) (*sqltypes.Result, error) { atomic.AddUint64(&vc.logStats.ShardQueries, 1) rss, _, err := vc.ResolveDestinations(ctx, keyspace, nil, []key.Destination{key.DestinationKeyspaceID(ksid)}) if err != nil { @@ -722,17 +823,17 @@ func (vc *vcursorImpl) ExecuteKeyspaceID(ctx context.Context, keyspace string, k // This creates a transaction but that transaction is for locking purpose only and should not cause multi-db transaction error. // This fields helps in to ignore multi-db transaction error when it states `queryFromVindex`. if !rollbackOnError { - vc.safeSession.queryFromVindex = true + vc.SafeSession.SetQueryFromVindex(true) defer func() { - vc.safeSession.queryFromVindex = false + vc.SafeSession.SetQueryFromVindex(false) }() } qr, errs := vc.ExecuteMultiShard(ctx, nil, rss, queries, rollbackOnError, autocommit) return qr, vterrors.Aggregate(errs) } -func (vc *vcursorImpl) InTransactionAndIsDML() bool { - if !vc.safeSession.InTransaction() { +func (vc *VCursorImpl) InTransactionAndIsDML() bool { + if !vc.SafeSession.InTransaction() { return false } switch vc.logStats.StmtType { @@ -742,7 +843,7 @@ func (vc *vcursorImpl) InTransactionAndIsDML() bool { return false } -func (vc *vcursorImpl) LookupRowLockShardSession() vtgatepb.CommitOrder { +func (vc *VCursorImpl) LookupRowLockShardSession() vtgatepb.CommitOrder { switch vc.logStats.StmtType { case "DELETE", "UPDATE": return vtgatepb.CommitOrder_POST @@ -751,23 +852,23 @@ func (vc *vcursorImpl) LookupRowLockShardSession() vtgatepb.CommitOrder { } // AutocommitApproval is part of the engine.VCursor interface. -func (vc *vcursorImpl) AutocommitApproval() bool { - return vc.safeSession.AutocommitApproval() +func (vc *VCursorImpl) AutocommitApproval() bool { + return vc.SafeSession.AutocommitApproval() } // setRollbackOnPartialExecIfRequired sets the value on SafeSession.rollbackOnPartialExec // when the query gets successfully executed on at least one shard, // there does not exist any old savepoint for which rollback is already set // and rollback on error is allowed. -func (vc *vcursorImpl) setRollbackOnPartialExecIfRequired(atleastOneSuccess bool, rollbackOnError bool) { - if atleastOneSuccess && rollbackOnError && !vc.safeSession.IsRollbackSet() { - vc.safeSession.SetRollbackCommand() +func (vc *VCursorImpl) setRollbackOnPartialExecIfRequired(atleastOneSuccess bool, rollbackOnError bool) { + if atleastOneSuccess && rollbackOnError && !vc.SafeSession.IsRollbackSet() { + vc.SafeSession.SetRollbackCommand() } } // fixupPartiallyMovedShards checks if any of the shards in the route has a ShardRoutingRule (true when a keyspace // is in the middle of being moved to another keyspace using MoveTables moving a subset of shards at a time -func (vc *vcursorImpl) fixupPartiallyMovedShards(rss []*srvtopo.ResolvedShard) ([]*srvtopo.ResolvedShard, error) { +func (vc *VCursorImpl) fixupPartiallyMovedShards(rss []*srvtopo.ResolvedShard) ([]*srvtopo.ResolvedShard, error) { if vc.vschema.ShardRoutingRules == nil { return rss, nil } @@ -784,12 +885,12 @@ func (vc *vcursorImpl) fixupPartiallyMovedShards(rss []*srvtopo.ResolvedShard) ( return rss, nil } -func (vc *vcursorImpl) ResolveDestinations(ctx context.Context, keyspace string, ids []*querypb.Value, destinations []key.Destination) ([]*srvtopo.ResolvedShard, [][]*querypb.Value, error) { +func (vc *VCursorImpl) ResolveDestinations(ctx context.Context, keyspace string, ids []*querypb.Value, destinations []key.Destination) ([]*srvtopo.ResolvedShard, [][]*querypb.Value, error) { rss, values, err := vc.resolver.ResolveDestinations(ctx, keyspace, vc.tabletType, ids, destinations) if err != nil { return nil, nil, err } - if enableShardRouting { + if vc.config.EnableShardRouting { rss, err = vc.fixupPartiallyMovedShards(rss) if err != nil { return nil, nil, err @@ -798,12 +899,12 @@ func (vc *vcursorImpl) ResolveDestinations(ctx context.Context, keyspace string, return rss, values, err } -func (vc *vcursorImpl) ResolveDestinationsMultiCol(ctx context.Context, keyspace string, ids [][]sqltypes.Value, destinations []key.Destination) ([]*srvtopo.ResolvedShard, [][][]sqltypes.Value, error) { +func (vc *VCursorImpl) ResolveDestinationsMultiCol(ctx context.Context, keyspace string, ids [][]sqltypes.Value, destinations []key.Destination) ([]*srvtopo.ResolvedShard, [][][]sqltypes.Value, error) { rss, values, err := vc.resolver.ResolveDestinationsMultiCol(ctx, keyspace, vc.tabletType, ids, destinations) if err != nil { return nil, nil, err } - if enableShardRouting { + if vc.config.EnableShardRouting { rss, err = vc.fixupPartiallyMovedShards(rss) if err != nil { return nil, nil, err @@ -812,12 +913,12 @@ func (vc *vcursorImpl) ResolveDestinationsMultiCol(ctx context.Context, keyspace return rss, values, err } -func (vc *vcursorImpl) Session() engine.SessionActions { +func (vc *VCursorImpl) Session() engine.SessionActions { return vc } -func (vc *vcursorImpl) SetTarget(target string) error { - keyspace, tabletType, _, err := topoprotopb.ParseDestination(target, defaultTabletType) +func (vc *VCursorImpl) SetTarget(target string) error { + keyspace, tabletType, _, err := topoprotopb.ParseDestination(target, vc.config.DefaultTabletType) if err != nil { return err } @@ -825,10 +926,12 @@ func (vc *vcursorImpl) SetTarget(target string) error { return vterrors.VT05003(keyspace) } - if vc.safeSession.InTransaction() && tabletType != topodatapb.TabletType_PRIMARY { + if vc.SafeSession.InTransaction() && tabletType != topodatapb.TabletType_PRIMARY { return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.LockOrActiveTransaction, "can't execute the given command because you have an active transaction") } - vc.safeSession.SetTargetString(target) + vc.SafeSession.SetTargetString(target) + vc.keyspace = keyspace + vc.tabletType = tabletType return nil } @@ -836,30 +939,30 @@ func ignoreKeyspace(keyspace string) bool { return keyspace == "" || sqlparser.SystemSchema(keyspace) } -func (vc *vcursorImpl) SetUDV(key string, value any) error { +func (vc *VCursorImpl) SetUDV(key string, value any) error { bindValue, err := sqltypes.BuildBindVariable(value) if err != nil { return err } - vc.safeSession.SetUserDefinedVariable(key, bindValue) + vc.SafeSession.SetUserDefinedVariable(key, bindValue) return nil } -func (vc *vcursorImpl) SetSysVar(name string, expr string) { - vc.safeSession.SetSystemVariable(name, expr) +func (vc *VCursorImpl) SetSysVar(name string, expr string) { + vc.SafeSession.SetSystemVariable(name, expr) } // NeedsReservedConn implements the SessionActions interface -func (vc *vcursorImpl) NeedsReservedConn() { - vc.safeSession.SetReservedConn(true) +func (vc *VCursorImpl) NeedsReservedConn() { + vc.SafeSession.SetReservedConn(true) } -func (vc *vcursorImpl) InReservedConn() bool { - return vc.safeSession.InReservedConn() +func (vc *VCursorImpl) InReservedConn() bool { + return vc.SafeSession.InReservedConn() } -func (vc *vcursorImpl) ShardSession() []*srvtopo.ResolvedShard { - ss := vc.safeSession.GetShardSessions() +func (vc *VCursorImpl) ShardSession() []*srvtopo.ResolvedShard { + ss := vc.SafeSession.GetShardSessions() if len(ss) == 0 { return nil } @@ -874,12 +977,12 @@ func (vc *vcursorImpl) ShardSession() []*srvtopo.ResolvedShard { } // Destination implements the ContextVSchema interface -func (vc *vcursorImpl) Destination() key.Destination { +func (vc *VCursorImpl) Destination() key.Destination { return vc.destination } // TabletType implements the ContextVSchema interface -func (vc *vcursorImpl) TabletType() topodatapb.TabletType { +func (vc *VCursorImpl) TabletType() topodatapb.TabletType { return vc.tabletType } @@ -898,13 +1001,13 @@ func commentedShardQueries(shardQueries []*querypb.BoundQuery, marginComments sq } // TargetDestination implements the ContextVSchema interface -func (vc *vcursorImpl) TargetDestination(qualifier string) (key.Destination, *vindexes.Keyspace, topodatapb.TabletType, error) { +func (vc *VCursorImpl) TargetDestination(qualifier string) (key.Destination, *vindexes.Keyspace, topodatapb.TabletType, error) { keyspaceName := vc.getActualKeyspace() if vc.destination == nil && qualifier != "" { keyspaceName = qualifier } if keyspaceName == "" { - return nil, nil, 0, errNoKeyspace + return nil, nil, 0, ErrNoKeyspace } keyspace := vc.vschema.Keyspaces[keyspaceName] if keyspace == nil { @@ -914,63 +1017,63 @@ func (vc *vcursorImpl) TargetDestination(qualifier string) (key.Destination, *vi } // SetAutocommit implements the SessionActions interface -func (vc *vcursorImpl) SetAutocommit(ctx context.Context, autocommit bool) error { - if autocommit && vc.safeSession.InTransaction() { - if err := vc.executor.Commit(ctx, vc.safeSession); err != nil { +func (vc *VCursorImpl) SetAutocommit(ctx context.Context, autocommit bool) error { + if autocommit && vc.SafeSession.InTransaction() { + if err := vc.executor.Commit(ctx, vc.SafeSession); err != nil { return err } } - vc.safeSession.Autocommit = autocommit + vc.SafeSession.Autocommit = autocommit return nil } // SetQueryTimeout implements the SessionActions interface -func (vc *vcursorImpl) SetQueryTimeout(maxExecutionTime int64) { - vc.safeSession.QueryTimeout = maxExecutionTime +func (vc *VCursorImpl) SetQueryTimeout(maxExecutionTime int64) { + vc.SafeSession.QueryTimeout = maxExecutionTime } // SetClientFoundRows implements the SessionActions interface -func (vc *vcursorImpl) SetClientFoundRows(_ context.Context, clientFoundRows bool) error { - vc.safeSession.GetOrCreateOptions().ClientFoundRows = clientFoundRows +func (vc *VCursorImpl) SetClientFoundRows(_ context.Context, clientFoundRows bool) error { + vc.SafeSession.GetOrCreateOptions().ClientFoundRows = clientFoundRows return nil } // SetSkipQueryPlanCache implements the SessionActions interface -func (vc *vcursorImpl) SetSkipQueryPlanCache(_ context.Context, skipQueryPlanCache bool) error { - vc.safeSession.GetOrCreateOptions().SkipQueryPlanCache = skipQueryPlanCache +func (vc *VCursorImpl) SetSkipQueryPlanCache(_ context.Context, skipQueryPlanCache bool) error { + vc.SafeSession.GetOrCreateOptions().SkipQueryPlanCache = skipQueryPlanCache return nil } // SetSQLSelectLimit implements the SessionActions interface -func (vc *vcursorImpl) SetSQLSelectLimit(limit int64) error { - vc.safeSession.GetOrCreateOptions().SqlSelectLimit = limit +func (vc *VCursorImpl) SetSQLSelectLimit(limit int64) error { + vc.SafeSession.GetOrCreateOptions().SqlSelectLimit = limit return nil } // SetTransactionMode implements the SessionActions interface -func (vc *vcursorImpl) SetTransactionMode(mode vtgatepb.TransactionMode) { - vc.safeSession.TransactionMode = mode +func (vc *VCursorImpl) SetTransactionMode(mode vtgatepb.TransactionMode) { + vc.SafeSession.TransactionMode = mode } // SetWorkload implements the SessionActions interface -func (vc *vcursorImpl) SetWorkload(workload querypb.ExecuteOptions_Workload) { - vc.safeSession.GetOrCreateOptions().Workload = workload +func (vc *VCursorImpl) SetWorkload(workload querypb.ExecuteOptions_Workload) { + vc.SafeSession.GetOrCreateOptions().Workload = workload } // SetPlannerVersion implements the SessionActions interface -func (vc *vcursorImpl) SetPlannerVersion(v plancontext.PlannerVersion) { - vc.safeSession.GetOrCreateOptions().PlannerVersion = v +func (vc *VCursorImpl) SetPlannerVersion(v plancontext.PlannerVersion) { + vc.SafeSession.GetOrCreateOptions().PlannerVersion = v } -func (vc *vcursorImpl) SetPriority(priority string) { +func (vc *VCursorImpl) SetPriority(priority string) { if priority != "" { - vc.safeSession.GetOrCreateOptions().Priority = priority - } else if vc.safeSession.Options != nil && vc.safeSession.Options.Priority != "" { - vc.safeSession.Options.Priority = "" + vc.SafeSession.GetOrCreateOptions().Priority = priority + } else if vc.SafeSession.Options != nil && vc.SafeSession.Options.Priority != "" { + vc.SafeSession.Options.Priority = "" } } -func (vc *vcursorImpl) SetExecQueryTimeout(timeout *int) { +func (vc *VCursorImpl) SetExecQueryTimeout(timeout *int) { // Determine the effective timeout: use passed timeout if non-nil, otherwise use session's query timeout if available var execTimeout *int if timeout != nil { @@ -981,153 +1084,152 @@ func (vc *vcursorImpl) SetExecQueryTimeout(timeout *int) { // If no effective timeout and no session options, return early if execTimeout == nil { - if vc.safeSession.GetOptions() == nil { + if vc.SafeSession.GetOptions() == nil { return } - vc.safeSession.GetOrCreateOptions().Timeout = nil + vc.SafeSession.GetOrCreateOptions().Timeout = nil return } vc.queryTimeout = time.Duration(*execTimeout) * time.Millisecond // Set the authoritative timeout using the determined execTimeout - vc.safeSession.GetOrCreateOptions().Timeout = &querypb.ExecuteOptions_AuthoritativeTimeout{ + vc.SafeSession.GetOrCreateOptions().Timeout = &querypb.ExecuteOptions_AuthoritativeTimeout{ AuthoritativeTimeout: int64(*execTimeout), } } // getQueryTimeout returns timeout based on the priority // session setting > global default specified by a flag. -func (vc *vcursorImpl) getQueryTimeout() int { - sessionQueryTimeout := int(vc.safeSession.GetQueryTimeout()) +func (vc *VCursorImpl) getQueryTimeout() int { + sessionQueryTimeout := int(vc.SafeSession.GetQueryTimeout()) if sessionQueryTimeout != 0 { return sessionQueryTimeout } - return queryTimeout + return vc.config.QueryTimeout } // SetConsolidator implements the SessionActions interface -func (vc *vcursorImpl) SetConsolidator(consolidator querypb.ExecuteOptions_Consolidator) { +func (vc *VCursorImpl) SetConsolidator(consolidator querypb.ExecuteOptions_Consolidator) { // Avoid creating session Options when they do not yet exist and the // consolidator is unspecified. - if consolidator == querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED && vc.safeSession.GetOptions() == nil { + if consolidator == querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED && vc.SafeSession.GetOptions() == nil { return } - vc.safeSession.GetOrCreateOptions().Consolidator = consolidator + vc.SafeSession.GetOrCreateOptions().Consolidator = consolidator } -func (vc *vcursorImpl) SetWorkloadName(workloadName string) { +func (vc *VCursorImpl) SetWorkloadName(workloadName string) { if workloadName != "" { - vc.safeSession.GetOrCreateOptions().WorkloadName = workloadName + vc.SafeSession.GetOrCreateOptions().WorkloadName = workloadName } } // SetFoundRows implements the SessionActions interface -func (vc *vcursorImpl) SetFoundRows(foundRows uint64) { - vc.safeSession.FoundRows = foundRows - vc.safeSession.foundRowsHandled = true +func (vc *VCursorImpl) SetFoundRows(foundRows uint64) { + vc.SafeSession.SetFoundRows(foundRows) } // SetDDLStrategy implements the SessionActions interface -func (vc *vcursorImpl) SetDDLStrategy(strategy string) { - vc.safeSession.SetDDLStrategy(strategy) +func (vc *VCursorImpl) SetDDLStrategy(strategy string) { + vc.SafeSession.SetDDLStrategy(strategy) } // GetDDLStrategy implements the SessionActions interface -func (vc *vcursorImpl) GetDDLStrategy() string { - return vc.safeSession.GetDDLStrategy() +func (vc *VCursorImpl) GetDDLStrategy() string { + return vc.SafeSession.GetDDLStrategy() } // SetMigrationContext implements the SessionActions interface -func (vc *vcursorImpl) SetMigrationContext(migrationContext string) { - vc.safeSession.SetMigrationContext(migrationContext) +func (vc *VCursorImpl) SetMigrationContext(migrationContext string) { + vc.SafeSession.SetMigrationContext(migrationContext) } // GetMigrationContext implements the SessionActions interface -func (vc *vcursorImpl) GetMigrationContext() string { - return vc.safeSession.GetMigrationContext() +func (vc *VCursorImpl) GetMigrationContext() string { + return vc.SafeSession.GetMigrationContext() } // GetSessionUUID implements the SessionActions interface -func (vc *vcursorImpl) GetSessionUUID() string { - return vc.safeSession.GetSessionUUID() +func (vc *VCursorImpl) GetSessionUUID() string { + return vc.SafeSession.GetSessionUUID() } // SetSessionEnableSystemSettings implements the SessionActions interface -func (vc *vcursorImpl) SetSessionEnableSystemSettings(_ context.Context, allow bool) error { - vc.safeSession.SetSessionEnableSystemSettings(allow) +func (vc *VCursorImpl) SetSessionEnableSystemSettings(_ context.Context, allow bool) error { + vc.SafeSession.SetSessionEnableSystemSettings(allow) return nil } // GetSessionEnableSystemSettings implements the SessionActions interface -func (vc *vcursorImpl) GetSessionEnableSystemSettings() bool { - return vc.safeSession.GetSessionEnableSystemSettings() +func (vc *VCursorImpl) GetSessionEnableSystemSettings() bool { + return vc.SafeSession.GetSessionEnableSystemSettings() } // SetReadAfterWriteGTID implements the SessionActions interface -func (vc *vcursorImpl) SetReadAfterWriteGTID(vtgtid string) { - vc.safeSession.SetReadAfterWriteGTID(vtgtid) +func (vc *VCursorImpl) SetReadAfterWriteGTID(vtgtid string) { + vc.SafeSession.SetReadAfterWriteGTID(vtgtid) } // SetReadAfterWriteTimeout implements the SessionActions interface -func (vc *vcursorImpl) SetReadAfterWriteTimeout(timeout float64) { - vc.safeSession.SetReadAfterWriteTimeout(timeout) +func (vc *VCursorImpl) SetReadAfterWriteTimeout(timeout float64) { + vc.SafeSession.SetReadAfterWriteTimeout(timeout) } // SetSessionTrackGTIDs implements the SessionActions interface -func (vc *vcursorImpl) SetSessionTrackGTIDs(enable bool) { - vc.safeSession.SetSessionTrackGtids(enable) +func (vc *VCursorImpl) SetSessionTrackGTIDs(enable bool) { + vc.SafeSession.SetSessionTrackGtids(enable) } // HasCreatedTempTable implements the SessionActions interface -func (vc *vcursorImpl) HasCreatedTempTable() { - vc.safeSession.GetOrCreateOptions().HasCreatedTempTables = true +func (vc *VCursorImpl) HasCreatedTempTable() { + vc.SafeSession.GetOrCreateOptions().HasCreatedTempTables = true } // GetWarnings implements the SessionActions interface -func (vc *vcursorImpl) GetWarnings() []*querypb.QueryWarning { - return vc.safeSession.GetWarnings() +func (vc *VCursorImpl) GetWarnings() []*querypb.QueryWarning { + return vc.SafeSession.GetWarnings() } // AnyAdvisoryLockTaken implements the SessionActions interface -func (vc *vcursorImpl) AnyAdvisoryLockTaken() bool { - return vc.safeSession.HasAdvisoryLock() +func (vc *VCursorImpl) AnyAdvisoryLockTaken() bool { + return vc.SafeSession.HasAdvisoryLock() } // AddAdvisoryLock implements the SessionActions interface -func (vc *vcursorImpl) AddAdvisoryLock(name string) { - vc.safeSession.AddAdvisoryLock(name) +func (vc *VCursorImpl) AddAdvisoryLock(name string) { + vc.SafeSession.AddAdvisoryLock(name) } // RemoveAdvisoryLock implements the SessionActions interface -func (vc *vcursorImpl) RemoveAdvisoryLock(name string) { - vc.safeSession.RemoveAdvisoryLock(name) +func (vc *VCursorImpl) RemoveAdvisoryLock(name string) { + vc.SafeSession.RemoveAdvisoryLock(name) } -func (vc *vcursorImpl) SetCommitOrder(co vtgatepb.CommitOrder) { - vc.safeSession.SetCommitOrder(co) +func (vc *VCursorImpl) SetCommitOrder(co vtgatepb.CommitOrder) { + vc.SafeSession.SetCommitOrder(co) } -func (vc *vcursorImpl) InTransaction() bool { - return vc.safeSession.InTransaction() +func (vc *VCursorImpl) InTransaction() bool { + return vc.SafeSession.InTransaction() } -func (vc *vcursorImpl) Commit(ctx context.Context) error { - return vc.executor.Commit(ctx, vc.safeSession) +func (vc *VCursorImpl) Commit(ctx context.Context) error { + return vc.executor.Commit(ctx, vc.SafeSession) } // GetDBDDLPluginName implements the VCursor interface -func (vc *vcursorImpl) GetDBDDLPluginName() string { - return dbDDLPlugin +func (vc *VCursorImpl) GetDBDDLPluginName() string { + return vc.config.DBDDLPlugin } // KeyspaceAvailable implements the VCursor interface -func (vc *vcursorImpl) KeyspaceAvailable(ks string) bool { +func (vc *VCursorImpl) KeyspaceAvailable(ks string) bool { _, exists := vc.executor.VSchema().Keyspaces[ks] return exists } // ErrorIfShardedF implements the VCursor interface -func (vc *vcursorImpl) ErrorIfShardedF(ks *vindexes.Keyspace, warn, errFormat string, params ...any) error { +func (vc *VCursorImpl) ErrorIfShardedF(ks *vindexes.Keyspace, warn, errFormat string, params ...any) error { if ks.Sharded { return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, errFormat, params...) } @@ -1136,19 +1238,25 @@ func (vc *vcursorImpl) ErrorIfShardedF(ks *vindexes.Keyspace, warn, errFormat st return nil } +func (vc *VCursorImpl) GetAndEmptyWarnings() []*querypb.QueryWarning { + w := vc.warnings + vc.warnings = nil + return w +} + // WarnUnshardedOnly implements the VCursor interface -func (vc *vcursorImpl) WarnUnshardedOnly(format string, params ...any) { - if vc.warnShardedOnly { +func (vc *VCursorImpl) WarnUnshardedOnly(format string, params ...any) { + if vc.config.WarnShardedOnly { vc.warnings = append(vc.warnings, &querypb.QueryWarning{ Code: uint32(sqlerror.ERNotSupportedYet), Message: fmt.Sprintf(format, params...), }) - warnings.Add("WarnUnshardedOnly", 1) + vc.executor.AddWarningCount("WarnUnshardedOnly", 1) } } // PlannerWarning implements the VCursor interface -func (vc *vcursorImpl) PlannerWarning(message string) { +func (vc *VCursorImpl) PlannerWarning(message string) { if message == "" { return } @@ -1159,8 +1267,8 @@ func (vc *vcursorImpl) PlannerWarning(message string) { } // ForeignKeyMode implements the VCursor interface -func (vc *vcursorImpl) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) { - if strings.ToLower(foreignKeyMode) == "disallow" { +func (vc *VCursorImpl) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_ForeignKeyMode, error) { + if vc.config.ForeignKeyMode == vschemapb.Keyspace_disallow { return vschemapb.Keyspace_disallow, nil } ks := vc.vschema.Keyspaces[keyspace] @@ -1170,7 +1278,7 @@ func (vc *vcursorImpl) ForeignKeyMode(keyspace string) (vschemapb.Keyspace_Forei return ks.ForeignKeyMode, nil } -func (vc *vcursorImpl) KeyspaceError(keyspace string) error { +func (vc *VCursorImpl) KeyspaceError(keyspace string) error { ks := vc.vschema.Keyspaces[keyspace] if ks == nil { return vterrors.VT14004(keyspace) @@ -1178,14 +1286,14 @@ func (vc *vcursorImpl) KeyspaceError(keyspace string) error { return ks.Error } -func (vc *vcursorImpl) GetAggregateUDFs() []string { +func (vc *VCursorImpl) GetAggregateUDFs() []string { return vc.vschema.GetAggregateUDFs() } // FindMirrorRule finds the mirror rule for the requested table name and // VSchema tablet type. -func (vc *vcursorImpl) FindMirrorRule(name sqlparser.TableName) (*vindexes.MirrorRule, error) { - destKeyspace, destTabletType, _, err := vc.executor.ParseDestinationTarget(name.Qualifier.String()) +func (vc *VCursorImpl) FindMirrorRule(name sqlparser.TableName) (*vindexes.MirrorRule, error) { + destKeyspace, destTabletType, _, err := vc.ParseDestinationTarget(name.Qualifier.String()) if err != nil { return nil, err } @@ -1199,23 +1307,11 @@ func (vc *vcursorImpl) FindMirrorRule(name sqlparser.TableName) (*vindexes.Mirro return mirrorRule, err } -// ParseDestinationTarget parses destination target string and sets default keyspace if possible. -func parseDestinationTarget(targetString string, vschema *vindexes.VSchema) (string, topodatapb.TabletType, key.Destination, error) { - destKeyspace, destTabletType, dest, err := topoprotopb.ParseDestination(targetString, defaultTabletType) - // Set default keyspace - if destKeyspace == "" && len(vschema.Keyspaces) == 1 { - for k := range vschema.Keyspaces { - destKeyspace = k - } - } - return destKeyspace, destTabletType, dest, err -} - -func (vc *vcursorImpl) keyForPlan(ctx context.Context, query string, buf io.StringWriter) { +func (vc *VCursorImpl) KeyForPlan(ctx context.Context, query string, buf io.StringWriter) { _, _ = buf.WriteString(vc.keyspace) _, _ = buf.WriteString(vindexes.TabletTypeSuffix[vc.tabletType]) _, _ = buf.WriteString("+Collate:") - _, _ = buf.WriteString(vc.Environment().CollationEnv().LookupName(vc.collation)) + _, _ = buf.WriteString(vc.Environment().CollationEnv().LookupName(vc.config.Collation)) if vc.destination != nil { switch vc.destination.(type) { @@ -1245,11 +1341,11 @@ func (vc *vcursorImpl) keyForPlan(ctx context.Context, query string, buf io.Stri _, _ = buf.WriteString(query) } -func (vc *vcursorImpl) GetKeyspace() string { +func (vc *VCursorImpl) GetKeyspace() string { return vc.keyspace } -func (vc *vcursorImpl) ExecuteVSchema(ctx context.Context, keyspace string, vschemaDDL *sqlparser.AlterVschema) error { +func (vc *VCursorImpl) ExecuteVSchema(ctx context.Context, keyspace string, vschemaDDL *sqlparser.AlterVschema) error { srvVschema := vc.vm.GetCurrentSrvVschema() if srvVschema == nil { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "vschema not loaded") @@ -1270,7 +1366,7 @@ func (vc *vcursorImpl) ExecuteVSchema(ctx context.Context, keyspace string, vsch ksName = keyspace } if ksName == "" { - return errNoKeyspace + return ErrNoKeyspace } ks := srvVschema.Keyspaces[ksName] @@ -1284,43 +1380,43 @@ func (vc *vcursorImpl) ExecuteVSchema(ctx context.Context, keyspace string, vsch return vc.vm.UpdateVSchema(ctx, ksName, srvVschema) } -func (vc *vcursorImpl) MessageStream(ctx context.Context, rss []*srvtopo.ResolvedShard, tableName string, callback func(*sqltypes.Result) error) error { +func (vc *VCursorImpl) MessageStream(ctx context.Context, rss []*srvtopo.ResolvedShard, tableName string, callback func(*sqltypes.Result) error) error { atomic.AddUint64(&vc.logStats.ShardQueries, uint64(len(rss))) return vc.executor.ExecuteMessageStream(ctx, rss, tableName, callback) } -func (vc *vcursorImpl) VStream(ctx context.Context, rss []*srvtopo.ResolvedShard, filter *binlogdatapb.Filter, gtid string, callback func(evs []*binlogdatapb.VEvent) error) error { +func (vc *VCursorImpl) VStream(ctx context.Context, rss []*srvtopo.ResolvedShard, filter *binlogdatapb.Filter, gtid string, callback func(evs []*binlogdatapb.VEvent) error) error { return vc.executor.ExecuteVStream(ctx, rss, filter, gtid, callback) } -func (vc *vcursorImpl) ShowExec(ctx context.Context, command sqlparser.ShowCommandType, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { +func (vc *VCursorImpl) ShowExec(ctx context.Context, command sqlparser.ShowCommandType, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { switch command { case sqlparser.VitessReplicationStatus: - return vc.executor.showVitessReplicationStatus(ctx, filter) + return vc.executor.ShowVitessReplicationStatus(ctx, filter) case sqlparser.VitessShards: - return vc.executor.showShards(ctx, filter, vc.tabletType) + return vc.executor.ShowShards(ctx, filter, vc.tabletType) case sqlparser.VitessTablets: - return vc.executor.showTablets(filter) + return vc.executor.ShowTablets(filter) case sqlparser.VitessVariables: - return vc.executor.showVitessMetadata(ctx, filter) + return vc.executor.ShowVitessMetadata(ctx, filter) default: return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "bug: unexpected show command: %v", command) } } -func (vc *vcursorImpl) GetVSchema() *vindexes.VSchema { +func (vc *VCursorImpl) GetVSchema() *vindexes.VSchema { return vc.vschema } -func (vc *vcursorImpl) GetSrvVschema() *vschemapb.SrvVSchema { +func (vc *VCursorImpl) GetSrvVschema() *vschemapb.SrvVSchema { return vc.vm.GetCurrentSrvVschema() } -func (vc *vcursorImpl) SetExec(ctx context.Context, name string, value string) error { - return vc.executor.setVitessMetadata(ctx, name, value) +func (vc *VCursorImpl) SetExec(ctx context.Context, name string, value string) error { + return vc.executor.SetVitessMetadata(ctx, name, value) } -func (vc *vcursorImpl) ThrottleApp(ctx context.Context, throttledAppRule *topodatapb.ThrottledAppRule) (err error) { +func (vc *VCursorImpl) ThrottleApp(ctx context.Context, throttledAppRule *topodatapb.ThrottledAppRule) (err error) { if throttledAppRule == nil { return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "ThrottleApp: nil rule") } @@ -1378,147 +1474,60 @@ func (vc *vcursorImpl) ThrottleApp(ctx context.Context, throttledAppRule *topoda return err } -func (vc *vcursorImpl) CanUseSetVar() bool { - return vc.Environment().Parser().IsMySQL80AndAbove() && setVarEnabled +func (vc *VCursorImpl) CanUseSetVar() bool { + return vc.Environment().Parser().IsMySQL80AndAbove() && vc.config.SetVarEnabled } -func (vc *vcursorImpl) ReleaseLock(ctx context.Context) error { - return vc.executor.ReleaseLock(ctx, vc.safeSession) +func (vc *VCursorImpl) ReleaseLock(ctx context.Context) error { + return vc.executor.ReleaseLock(ctx, vc.SafeSession) } -func (vc *vcursorImpl) cloneWithAutocommitSession() *vcursorImpl { - safeSession := NewAutocommitSession(vc.safeSession.Session) - safeSession.logging = vc.safeSession.logging - return &vcursorImpl{ - safeSession: safeSession, - keyspace: vc.keyspace, - tabletType: vc.tabletType, - destination: vc.destination, - marginComments: vc.marginComments, - executor: vc.executor, - logStats: vc.logStats, - collation: vc.collation, - resolver: vc.resolver, - vschema: vc.vschema, - vm: vc.vm, - topoServer: vc.topoServer, - warnShardedOnly: vc.warnShardedOnly, - pv: vc.pv, - resultsObserver: vc.resultsObserver, - } -} - -func (vc *vcursorImpl) VExplainLogging() { - vc.safeSession.EnableLogging(vc.Environment().Parser()) +func (vc *VCursorImpl) VExplainLogging() { + vc.SafeSession.EnableLogging(vc.Environment().Parser()) } -func (vc *vcursorImpl) GetVExplainLogs() []engine.ExecuteEntry { - return vc.safeSession.logging.GetLogs() +func (vc *VCursorImpl) GetVExplainLogs() []engine.ExecuteEntry { + return vc.SafeSession.GetLogs() } -func (vc *vcursorImpl) FindRoutedShard(keyspace, shard string) (keyspaceName string, err error) { +func (vc *VCursorImpl) FindRoutedShard(keyspace, shard string) (keyspaceName string, err error) { return vc.vschema.FindRoutedShard(keyspace, shard) } -func (vc *vcursorImpl) IsViewsEnabled() bool { - return enableViews -} - -func (vc *vcursorImpl) GetUDV(name string) *querypb.BindVariable { - return vc.safeSession.GetUDV(name) +func (vc *VCursorImpl) IsViewsEnabled() bool { + return vc.config.EnableViews } -func (vc *vcursorImpl) PlanPrepareStatement(ctx context.Context, query string) (*engine.Plan, sqlparser.Statement, error) { - return vc.executor.planPrepareStmt(ctx, vc, query) +func (vc *VCursorImpl) GetUDV(name string) *querypb.BindVariable { + return vc.SafeSession.GetUDV(name) } -func (vc *vcursorImpl) ClearPrepareData(name string) { - delete(vc.safeSession.PrepareStatement, name) +func (vc *VCursorImpl) PlanPrepareStatement(ctx context.Context, query string) (*engine.Plan, sqlparser.Statement, error) { + return vc.executor.PlanPrepareStmt(ctx, vc, query) } -func (vc *vcursorImpl) StorePrepareData(stmtName string, prepareData *vtgatepb.PrepareData) { - vc.safeSession.StorePrepareData(stmtName, prepareData) +func (vc *VCursorImpl) ClearPrepareData(name string) { + delete(vc.SafeSession.PrepareStatement, name) } -func (vc *vcursorImpl) GetPrepareData(stmtName string) *vtgatepb.PrepareData { - return vc.safeSession.GetPrepareData(stmtName) +func (vc *VCursorImpl) StorePrepareData(stmtName string, prepareData *vtgatepb.PrepareData) { + vc.SafeSession.StorePrepareData(stmtName, prepareData) } -func (vc *vcursorImpl) GetWarmingReadsPercent() int { - return vc.warmingReadsPercent +func (vc *VCursorImpl) GetPrepareData(stmtName string) *vtgatepb.PrepareData { + return vc.SafeSession.GetPrepareData(stmtName) } -func (vc *vcursorImpl) GetWarmingReadsChannel() chan bool { - return vc.warmingReadsChannel -} - -func (vc *vcursorImpl) CloneForReplicaWarming(ctx context.Context) engine.VCursor { - callerId := callerid.EffectiveCallerIDFromContext(ctx) - immediateCallerId := callerid.ImmediateCallerIDFromContext(ctx) - - timedCtx, _ := context.WithTimeout(context.Background(), warmingReadsQueryTimeout) // nolint - clonedCtx := callerid.NewContext(timedCtx, callerId, immediateCallerId) - - v := &vcursorImpl{ - safeSession: NewAutocommitSession(vc.safeSession.Session), - keyspace: vc.keyspace, - tabletType: topodatapb.TabletType_REPLICA, - destination: vc.destination, - marginComments: vc.marginComments, - executor: vc.executor, - resolver: vc.resolver, - topoServer: vc.topoServer, - logStats: &logstats.LogStats{Ctx: clonedCtx}, - collation: vc.collation, - ignoreMaxMemoryRows: vc.ignoreMaxMemoryRows, - vschema: vc.vschema, - vm: vc.vm, - semTable: vc.semTable, - warnShardedOnly: vc.warnShardedOnly, - warnings: vc.warnings, - pv: vc.pv, - resultsObserver: nullResultsObserver{}, - } - - v.marginComments.Trailing += "/* warming read */" - - return v +func (vc *VCursorImpl) GetWarmingReadsPercent() int { + return vc.config.WarmingReadsPercent } -func (vc *vcursorImpl) CloneForMirroring(ctx context.Context) engine.VCursor { - callerId := callerid.EffectiveCallerIDFromContext(ctx) - immediateCallerId := callerid.ImmediateCallerIDFromContext(ctx) - - clonedCtx := callerid.NewContext(ctx, callerId, immediateCallerId) - - v := &vcursorImpl{ - safeSession: NewAutocommitSession(vc.safeSession.Session), - keyspace: vc.keyspace, - tabletType: vc.tabletType, - destination: vc.destination, - marginComments: vc.marginComments, - executor: vc.executor, - resolver: vc.resolver, - topoServer: vc.topoServer, - logStats: &logstats.LogStats{Ctx: clonedCtx}, - collation: vc.collation, - ignoreMaxMemoryRows: vc.ignoreMaxMemoryRows, - vschema: vc.vschema, - vm: vc.vm, - semTable: vc.semTable, - warnShardedOnly: vc.warnShardedOnly, - warnings: vc.warnings, - pv: vc.pv, - resultsObserver: nullResultsObserver{}, - } - - v.marginComments.Trailing += "/* mirror query */" - - return v +func (vc *VCursorImpl) GetWarmingReadsChannel() chan bool { + return vc.config.WarmingReadsChannel } // UpdateForeignKeyChecksState updates the foreign key checks state of the vcursor. -func (vc *vcursorImpl) UpdateForeignKeyChecksState(fkStateFromQuery *bool) { +func (vc *VCursorImpl) UpdateForeignKeyChecksState(fkStateFromQuery *bool) { // Initialize the state to unspecified. vc.fkChecksState = nil // If the query has a SET_VAR optimizer hint that explicitly sets the foreign key checks state, @@ -1528,17 +1537,36 @@ func (vc *vcursorImpl) UpdateForeignKeyChecksState(fkStateFromQuery *bool) { return } // If the query doesn't have anything, then we consult the session state. - vc.fkChecksState = vc.safeSession.ForeignKeyChecks() + vc.fkChecksState = vc.SafeSession.ForeignKeyChecks() } // GetForeignKeyChecksState gets the stored foreign key checks state in the vcursor. -func (vc *vcursorImpl) GetForeignKeyChecksState() *bool { +func (vc *VCursorImpl) GetForeignKeyChecksState() *bool { return vc.fkChecksState } // RecordMirrorStats is used to record stats about a mirror query. -func (vc *vcursorImpl) RecordMirrorStats(sourceExecTime, targetExecTime time.Duration, targetErr error) { +func (vc *VCursorImpl) RecordMirrorStats(sourceExecTime, targetExecTime time.Duration, targetErr error) { vc.logStats.MirrorSourceExecuteTime = sourceExecTime vc.logStats.MirrorTargetExecuteTime = targetExecTime vc.logStats.MirrorTargetError = targetErr } + +func (vc *VCursorImpl) GetMarginComments() sqlparser.MarginComments { + return vc.marginComments +} + +func (vc *VCursorImpl) CachePlan() bool { + return vc.SafeSession.CachePlan() +} + +func (vc *VCursorImpl) GetContextWithTimeOut(ctx context.Context) (context.Context, context.CancelFunc) { + if vc.queryTimeout == 0 { + return ctx, func() {} + } + return context.WithTimeout(ctx, vc.queryTimeout) +} + +func (vc *VCursorImpl) IgnoreMaxMemoryRows() bool { + return vc.ignoreMaxMemoryRows +} diff --git a/go/vt/vtgate/vcursor_impl_test.go b/go/vt/vtgate/executorcontext/vcursor_impl_test.go similarity index 60% rename from go/vt/vtgate/vcursor_impl_test.go rename to go/vt/vtgate/executorcontext/vcursor_impl_test.go index 95d9a18078d..16d2c03bf1c 100644 --- a/go/vt/vtgate/vcursor_impl_test.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl_test.go @@ -1,8 +1,23 @@ -package vtgate +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package executorcontext import ( "context" - "encoding/hex" "errors" "fmt" "strconv" @@ -12,10 +27,16 @@ import ( "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + "vitess.io/vitess/go/vt/vtenv" + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/vtgateservice" + "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/srvtopo" - "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/vtgate/logstats" "vitess.io/vitess/go/vt/vtgate/vindexes" @@ -39,48 +60,6 @@ func (f fakeVSchemaOperator) UpdateVSchema(ctx context.Context, ksName string, v panic("implement me") } -type fakeTopoServer struct{} - -// GetTopoServer returns the full topo.Server instance. -func (f *fakeTopoServer) GetTopoServer() (*topo.Server, error) { - return nil, nil -} - -// GetSrvKeyspaceNames returns the list of keyspaces served in -// the provided cell. -func (f *fakeTopoServer) GetSrvKeyspaceNames(ctx context.Context, cell string, staleOK bool) ([]string, error) { - return []string{"ks1"}, nil -} - -// GetSrvKeyspace returns the SrvKeyspace for a cell/keyspace. -func (f *fakeTopoServer) GetSrvKeyspace(ctx context.Context, cell, keyspace string) (*topodatapb.SrvKeyspace, error) { - zeroHexBytes, _ := hex.DecodeString("") - eightyHexBytes, _ := hex.DecodeString("80") - ks := &topodatapb.SrvKeyspace{ - Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{ - { - ServedType: topodatapb.TabletType_PRIMARY, - ShardReferences: []*topodatapb.ShardReference{ - {Name: "-80", KeyRange: &topodatapb.KeyRange{Start: zeroHexBytes, End: eightyHexBytes}}, - {Name: "80-", KeyRange: &topodatapb.KeyRange{Start: eightyHexBytes, End: zeroHexBytes}}, - }, - }, - }, - } - return ks, nil -} - -func (f *fakeTopoServer) WatchSrvKeyspace(ctx context.Context, cell, keyspace string, callback func(*topodatapb.SrvKeyspace, error) bool) { - ks, err := f.GetSrvKeyspace(ctx, cell, keyspace) - callback(ks, err) -} - -// WatchSrvVSchema starts watching the SrvVSchema object for -// the provided cell. It will call the callback when -// a new value or an error occurs. -func (f *fakeTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) { -} - func TestDestinationKeyspace(t *testing.T) { ks1 := &vindexes.Keyspace{ Name: "ks1", @@ -184,13 +163,17 @@ func TestDestinationKeyspace(t *testing.T) { }, { vschema: vschemaWith2KS, targetString: "", - expectedError: errNoKeyspace.Error(), + expectedError: ErrNoKeyspace.Error(), }} - r, _, _, _, _ := createExecutorEnv(t) for i, tc := range tests { t.Run(strconv.Itoa(i)+tc.targetString, func(t *testing.T) { - impl, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{TargetString: tc.targetString}), sqlparser.MarginComments{}, r, nil, &fakeVSchemaOperator{vschema: tc.vschema}, tc.vschema, nil, nil, false, querypb.ExecuteOptions_Gen4) + session := NewSafeSession(&vtgatepb.Session{TargetString: tc.targetString}) + impl, _ := NewVCursorImpl(session, sqlparser.MarginComments{}, nil, nil, + &fakeVSchemaOperator{vschema: tc.vschema}, tc.vschema, nil, nil, + fakeObserver{}, VCursorConfig{ + DefaultTabletType: topodatapb.TabletType_PRIMARY, + }) impl.vschema = tc.vschema dest, keyspace, tabletType, err := impl.TargetDestination(tc.qualifier) if tc.expectedError == "" { @@ -250,15 +233,15 @@ func TestSetTarget(t *testing.T) { expectedError: "can't execute the given command because you have an active transaction", }} - r, _, _, _, _ := createExecutorEnv(t) for i, tc := range tests { t.Run(fmt.Sprintf("%d#%s", i, tc.targetString), func(t *testing.T) { - vc, _ := newVCursorImpl(NewSafeSession(&vtgatepb.Session{InTransaction: true}), sqlparser.MarginComments{}, r, nil, &fakeVSchemaOperator{vschema: tc.vschema}, tc.vschema, nil, nil, false, querypb.ExecuteOptions_Gen4) + cfg := VCursorConfig{DefaultTabletType: topodatapb.TabletType_PRIMARY} + vc, _ := NewVCursorImpl(NewSafeSession(&vtgatepb.Session{InTransaction: true}), sqlparser.MarginComments{}, nil, nil, &fakeVSchemaOperator{vschema: tc.vschema}, tc.vschema, nil, nil, fakeObserver{}, cfg) vc.vschema = tc.vschema err := vc.SetTarget(tc.targetString) if tc.expectedError == "" { require.NoError(t, err) - require.Equal(t, vc.safeSession.TargetString, tc.targetString) + require.Equal(t, vc.SafeSession.TargetString, tc.targetString) } else { require.EqualError(t, err, tc.expectedError) } @@ -299,17 +282,20 @@ func TestKeyForPlan(t *testing.T) { expectedPlanPrefixKey: "ks1@replica+Collate:utf8mb4_0900_ai_ci+Query:SELECT 1", }} - r, _, _, _, _ := createExecutorEnv(t) for i, tc := range tests { t.Run(fmt.Sprintf("%d#%s", i, tc.targetString), func(t *testing.T) { ss := NewSafeSession(&vtgatepb.Session{InTransaction: false}) ss.SetTargetString(tc.targetString) - vc, err := newVCursorImpl(ss, sqlparser.MarginComments{}, r, nil, &fakeVSchemaOperator{vschema: tc.vschema}, tc.vschema, srvtopo.NewResolver(&fakeTopoServer{}, nil, ""), nil, false, querypb.ExecuteOptions_Gen4) + cfg := VCursorConfig{ + Collation: collations.CollationUtf8mb4ID, + DefaultTabletType: topodatapb.TabletType_PRIMARY, + } + vc, err := NewVCursorImpl(ss, sqlparser.MarginComments{}, &fakeExecutor{}, nil, &fakeVSchemaOperator{vschema: tc.vschema}, tc.vschema, srvtopo.NewResolver(&FakeTopoServer{}, nil, ""), nil, fakeObserver{}, cfg) require.NoError(t, err) vc.vschema = tc.vschema var buf strings.Builder - vc.keyForPlan(context.Background(), "SELECT 1", &buf) + vc.KeyForPlan(context.Background(), "SELECT 1", &buf) require.Equal(t, tc.expectedPlanPrefixKey, buf.String()) }) } @@ -327,8 +313,7 @@ func TestFirstSortedKeyspace(t *testing.T) { }, } - r, _, _, _, _ := createExecutorEnv(t) - vc, err := newVCursorImpl(NewSafeSession(nil), sqlparser.MarginComments{}, r, nil, &fakeVSchemaOperator{vschema: vschemaWith2KS}, vschemaWith2KS, srvtopo.NewResolver(&fakeTopoServer{}, nil, ""), nil, false, querypb.ExecuteOptions_Gen4) + vc, err := NewVCursorImpl(NewSafeSession(nil), sqlparser.MarginComments{}, nil, nil, &fakeVSchemaOperator{vschema: vschemaWith2KS}, vschemaWith2KS, srvtopo.NewResolver(&FakeTopoServer{}, nil, ""), nil, fakeObserver{}, VCursorConfig{}) require.NoError(t, err) ks, err := vc.FirstSortedKeyspace() require.NoError(t, err) @@ -338,13 +323,13 @@ func TestFirstSortedKeyspace(t *testing.T) { // TestSetExecQueryTimeout tests the SetExecQueryTimeout method. // Validates the timeout value is set based on override rule. func TestSetExecQueryTimeout(t *testing.T) { - executor, _, _, _, _ := createExecutorEnv(t) safeSession := NewSafeSession(nil) - vc, err := newVCursorImpl(safeSession, sqlparser.MarginComments{}, executor, nil, nil, &vindexes.VSchema{}, nil, nil, false, querypb.ExecuteOptions_Gen4) + vc, err := NewVCursorImpl(safeSession, sqlparser.MarginComments{}, nil, nil, nil, &vindexes.VSchema{}, nil, nil, fakeObserver{}, VCursorConfig{ + // flag timeout + QueryTimeout: 20, + }) require.NoError(t, err) - // flag timeout - queryTimeout = 20 vc.SetExecQueryTimeout(nil) require.Equal(t, 20*time.Millisecond, vc.queryTimeout) require.NotNil(t, safeSession.Options.Timeout) @@ -371,8 +356,8 @@ func TestSetExecQueryTimeout(t *testing.T) { require.NotNil(t, safeSession.Options.Timeout) require.EqualValues(t, 0, safeSession.Options.GetAuthoritativeTimeout()) - // reset - queryTimeout = 0 + // reset flag timeout + vc.config.QueryTimeout = 0 safeSession.SetQueryTimeout(0) vc.SetExecQueryTimeout(nil) require.Equal(t, 0*time.Millisecond, vc.queryTimeout) @@ -381,10 +366,9 @@ func TestSetExecQueryTimeout(t *testing.T) { } func TestRecordMirrorStats(t *testing.T) { - executor, _, _, _, _ := createExecutorEnv(t) safeSession := NewSafeSession(nil) logStats := logstats.NewLogStats(context.Background(), t.Name(), "select 1", "", nil) - vc, err := newVCursorImpl(safeSession, sqlparser.MarginComments{}, executor, logStats, nil, &vindexes.VSchema{}, nil, nil, false, querypb.ExecuteOptions_Gen4) + vc, err := NewVCursorImpl(safeSession, sqlparser.MarginComments{}, nil, logStats, nil, &vindexes.VSchema{}, nil, nil, fakeObserver{}, VCursorConfig{}) require.NoError(t, err) require.Zero(t, logStats.MirrorSourceExecuteTime) @@ -397,3 +381,113 @@ func TestRecordMirrorStats(t *testing.T) { require.Equal(t, 20*time.Millisecond, logStats.MirrorTargetExecuteTime) require.ErrorContains(t, logStats.MirrorTargetError, "test error") } + +type fakeExecutor struct{} + +func (f fakeExecutor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, session *SafeSession, s string, vars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver ResultsObserver) (qr *sqltypes.Result, errs []error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, observer ResultsObserver) []error { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) Commit(ctx context.Context, safeSession *SafeSession) error { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ExecuteMessageStream(ctx context.Context, rss []*srvtopo.ResolvedShard, name string, callback func(*sqltypes.Result) error) error { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ExecuteVStream(ctx context.Context, rss []*srvtopo.ResolvedShard, filter *binlogdatapb.Filter, gtid string, callback func(evs []*binlogdatapb.VEvent) error) error { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ReleaseLock(ctx context.Context, session *SafeSession) error { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ShowVitessReplicationStatus(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ShowShards(ctx context.Context, filter *sqlparser.ShowFilter, destTabletType topodatapb.TabletType) (*sqltypes.Result, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ShowTablets(filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ShowVitessMetadata(ctx context.Context, filter *sqlparser.ShowFilter) (*sqltypes.Result, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) SetVitessMetadata(ctx context.Context, name, value string) error { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) ParseDestinationTarget(targetString string) (string, topodatapb.TabletType, key.Destination, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) VSchema() *vindexes.VSchema { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) PlanPrepareStmt(ctx context.Context, vcursor *VCursorImpl, query string) (*engine.Plan, sqlparser.Statement, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) Environment() *vtenv.Environment { + return vtenv.NewTestEnv() +} + +func (f fakeExecutor) ReadTransaction(ctx context.Context, transactionID string) (*querypb.TransactionMetadata, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) UnresolvedTransactions(ctx context.Context, targets []*querypb.Target) ([]*querypb.TransactionMetadata, error) { + // TODO implement me + panic("implement me") +} + +func (f fakeExecutor) AddWarningCount(name string, value int64) { + // TODO implement me + panic("implement me") +} + +var _ iExecute = (*fakeExecutor)(nil) + +type fakeObserver struct{} + +func (f fakeObserver) Observe(*sqltypes.Result) { +} + +var _ ResultsObserver = (*fakeObserver)(nil) diff --git a/go/vt/vtgate/legacy_scatter_conn_test.go b/go/vt/vtgate/legacy_scatter_conn_test.go index 4512fc0724e..0d49e7b7bd9 100644 --- a/go/vt/vtgate/legacy_scatter_conn_test.go +++ b/go/vt/vtgate/legacy_scatter_conn_test.go @@ -26,6 +26,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -99,7 +101,7 @@ func TestLegacyExecuteFailOnAutocommit(t *testing.T) { }, Autocommit: false, } - _, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}) + _, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, econtext.NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}) err := vterrors.Aggregate(errs) require.Error(t, err) require.Contains(t, err.Error(), "in autocommit mode, transactionID should be zero but was: 123") @@ -123,7 +125,7 @@ func TestScatterConnExecuteMulti(t *testing.T) { } } - qr, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(nil), false /*autocommit*/, false, nullResultsObserver{}) + qr, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, econtext.NewSafeSession(nil), false /*autocommit*/, false, nullResultsObserver{}) return qr, vterrors.Aggregate(errs) }) } @@ -138,7 +140,7 @@ func TestScatterConnStreamExecuteMulti(t *testing.T) { bvs := make([]map[string]*querypb.BindVariable, len(rss)) qr := new(sqltypes.Result) var mu sync.Mutex - errors := sc.StreamExecuteMulti(ctx, nil, "query", rss, bvs, NewSafeSession(&vtgatepb.Session{InTransaction: true}), true /* autocommit */, func(r *sqltypes.Result) error { + errors := sc.StreamExecuteMulti(ctx, nil, "query", rss, bvs, econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}), true /* autocommit */, func(r *sqltypes.Result) error { mu.Lock() defer mu.Unlock() qr.AppendResult(r) @@ -280,7 +282,7 @@ func TestMaxMemoryRows(t *testing.T) { []key.Destination{key.DestinationShard("0"), key.DestinationShard("1")}) require.NoError(t, err) - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) queries := []*querypb.BoundQuery{{ Sql: "query1", BindVariables: map[string]*querypb.BindVariable{}, @@ -328,7 +330,7 @@ func TestLegaceHealthCheckFailsOnReservedConnections(t *testing.T) { res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa") - session := NewSafeSession(&vtgatepb.Session{InTransaction: false, InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: false, InReservedConn: true}) destinations := []key.Destination{key.DestinationShard("0")} rss, _, err := res.ResolveDestinations(ctx, keyspace, topodatapb.TabletType_REPLICA, nil, destinations) require.NoError(t, err) @@ -346,12 +348,12 @@ func TestLegaceHealthCheckFailsOnReservedConnections(t *testing.T) { require.Error(t, vterrors.Aggregate(errs)) } -func executeOnShards(t *testing.T, ctx context.Context, res *srvtopo.Resolver, keyspace string, sc *ScatterConn, session *SafeSession, destinations []key.Destination) { +func executeOnShards(t *testing.T, ctx context.Context, res *srvtopo.Resolver, keyspace string, sc *ScatterConn, session *econtext.SafeSession, destinations []key.Destination) { t.Helper() require.Empty(t, executeOnShardsReturnsErr(t, ctx, res, keyspace, sc, session, destinations)) } -func executeOnShardsReturnsErr(t *testing.T, ctx context.Context, res *srvtopo.Resolver, keyspace string, sc *ScatterConn, session *SafeSession, destinations []key.Destination) error { +func executeOnShardsReturnsErr(t *testing.T, ctx context.Context, res *srvtopo.Resolver, keyspace string, sc *ScatterConn, session *econtext.SafeSession, destinations []key.Destination) error { t.Helper() rss, _, err := res.ResolveDestinations(ctx, keyspace, topodatapb.TabletType_REPLICA, nil, destinations) require.NoError(t, err) @@ -374,7 +376,7 @@ type recordingResultsObserver struct { recorded []*sqltypes.Result } -func (o *recordingResultsObserver) observe(result *sqltypes.Result) { +func (o *recordingResultsObserver) Observe(result *sqltypes.Result) { mu.Lock() o.recorded = append(o.recorded, result) mu.Unlock() @@ -429,7 +431,7 @@ func TestMultiExecs(t *testing.T) { observer := recordingResultsObserver{} - session := NewSafeSession(&vtgatepb.Session{}) + session := econtext.NewSafeSession(&vtgatepb.Session{}) _, err := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, &observer) require.NoError(t, vterrors.Aggregate(err)) if len(sbc0.Queries) == 0 || len(sbc1.Queries) == 0 { @@ -511,7 +513,7 @@ func TestScatterConnSingleDB(t *testing.T) { want := "multi-db transaction attempted" // TransactionMode_SINGLE in session - session := NewSafeSession(&vtgatepb.Session{InTransaction: true, TransactionMode: vtgatepb.TransactionMode_SINGLE}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true, TransactionMode: vtgatepb.TransactionMode_SINGLE}) queries := []*querypb.BoundQuery{{Sql: "query1"}} _, errors := sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) require.Empty(t, errors) @@ -521,7 +523,7 @@ func TestScatterConnSingleDB(t *testing.T) { // TransactionMode_SINGLE in txconn sc.txConn.mode = vtgatepb.TransactionMode_SINGLE - session = NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session = econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) _, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) require.Empty(t, errors) _, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{}) @@ -530,7 +532,7 @@ func TestScatterConnSingleDB(t *testing.T) { // TransactionMode_MULTI in txconn. Should not fail. sc.txConn.mode = vtgatepb.TransactionMode_MULTI - session = NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session = econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) _, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) require.Empty(t, errors) _, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{}) @@ -601,7 +603,7 @@ func TestReservePrequeries(t *testing.T) { res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa") - session := NewSafeSession(&vtgatepb.Session{ + session := econtext.NewSafeSession(&vtgatepb.Session{ InTransaction: false, InReservedConn: true, SystemVariables: map[string]string{ diff --git a/go/vt/vtgate/plan_execute.go b/go/vt/vtgate/plan_execute.go index 1c0915470ef..db7923c09f0 100644 --- a/go/vt/vtgate/plan_execute.go +++ b/go/vt/vtgate/plan_execute.go @@ -29,11 +29,12 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/logstats" "vitess.io/vitess/go/vt/vtgate/vtgateservice" ) -type planExec func(ctx context.Context, plan *engine.Plan, vc *vcursorImpl, bindVars map[string]*querypb.BindVariable, startTime time.Time) error +type planExec func(ctx context.Context, plan *engine.Plan, vc *econtext.VCursorImpl, bindVars map[string]*querypb.BindVariable, startTime time.Time) error type txResult func(sqlparser.StatementType, *sqltypes.Result) error var vschemaWaitTimeout = 30 * time.Second @@ -56,10 +57,12 @@ func waitForNewerVSchema(ctx context.Context, e *Executor, lastVSchemaCreated ti } } +const MaxBufferingRetries = 3 + func (e *Executor) newExecute( ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, - safeSession *SafeSession, + safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats, @@ -116,7 +119,7 @@ func (e *Executor) newExecute( } } - vcursor, err := newVCursorImpl(safeSession, comments, e, logStats, e.vm, vs, e.resolver.resolver, e.serv, e.warnShardedOnly, e.pv) + vcursor, err := econtext.NewVCursorImpl(safeSession, comments, e, logStats, e.vm, vs, e.resolver.resolver, e.serv, nullResultsObserver{}, e.vConfig) if err != nil { return err } @@ -146,10 +149,8 @@ func (e *Executor) newExecute( } // set the overall query timeout if it is not already set - if vcursor.queryTimeout > 0 && cancel == nil { - ctx, cancel = context.WithTimeout(ctx, vcursor.queryTimeout) - defer cancel() - } + ctx, cancel = vcursor.GetContextWithTimeOut(ctx) + defer cancel() result, err = e.handleTransactions(ctx, mysqlCtx, safeSession, plan, logStats, vcursor, stmt) if err != nil { @@ -225,10 +226,10 @@ func (e *Executor) newExecute( func (e *Executor) handleTransactions( ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, - safeSession *SafeSession, + safeSession *econtext.SafeSession, plan *engine.Plan, logStats *logstats.LogStats, - vcursor *vcursorImpl, + vcursor *econtext.VCursorImpl, stmt sqlparser.Statement, ) (*sqltypes.Result, error) { // We need to explicitly handle errors, and begin/commit/rollback, since these control transactions. Everything else @@ -247,19 +248,19 @@ func (e *Executor) handleTransactions( qr, err := e.handleSavepoint(ctx, safeSession, plan.Original, "Savepoint", logStats, func(_ string) (*sqltypes.Result, error) { // Safely to ignore as there is no transaction. return &sqltypes.Result{}, nil - }, vcursor.ignoreMaxMemoryRows) + }, vcursor.IgnoreMaxMemoryRows()) return qr, err case sqlparser.StmtSRollback: qr, err := e.handleSavepoint(ctx, safeSession, plan.Original, "Rollback Savepoint", logStats, func(query string) (*sqltypes.Result, error) { // Error as there is no transaction, so there is no savepoint that exists. return nil, vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.SPDoesNotExist, "SAVEPOINT does not exist: %s", query) - }, vcursor.ignoreMaxMemoryRows) + }, vcursor.IgnoreMaxMemoryRows()) return qr, err case sqlparser.StmtRelease: qr, err := e.handleSavepoint(ctx, safeSession, plan.Original, "Release Savepoint", logStats, func(query string) (*sqltypes.Result, error) { // Error as there is no transaction, so there is no savepoint that exists. return nil, vterrors.NewErrorf(vtrpcpb.Code_NOT_FOUND, vterrors.SPDoesNotExist, "SAVEPOINT does not exist: %s", query) - }, vcursor.ignoreMaxMemoryRows) + }, vcursor.IgnoreMaxMemoryRows()) return qr, err case sqlparser.StmtKill: return e.handleKill(ctx, mysqlCtx, stmt, logStats) @@ -267,7 +268,7 @@ func (e *Executor) handleTransactions( return nil, nil } -func (e *Executor) startTxIfNecessary(ctx context.Context, safeSession *SafeSession) error { +func (e *Executor) startTxIfNecessary(ctx context.Context, safeSession *econtext.SafeSession) error { if !safeSession.Autocommit && !safeSession.InTransaction() { if err := e.txConn.Begin(ctx, safeSession, nil); err != nil { return err @@ -276,7 +277,7 @@ func (e *Executor) startTxIfNecessary(ctx context.Context, safeSession *SafeSess return nil } -func (e *Executor) insideTransaction(ctx context.Context, safeSession *SafeSession, logStats *logstats.LogStats, execPlan func() error) error { +func (e *Executor) insideTransaction(ctx context.Context, safeSession *econtext.SafeSession, logStats *logstats.LogStats, execPlan func() error) error { mustCommit := false if safeSession.Autocommit && !safeSession.InTransaction() { mustCommit = true @@ -320,9 +321,9 @@ func (e *Executor) insideTransaction(ctx context.Context, safeSession *SafeSessi func (e *Executor) executePlan( ctx context.Context, - safeSession *SafeSession, + safeSession *econtext.SafeSession, plan *engine.Plan, - vcursor *vcursorImpl, + vcursor *econtext.VCursorImpl, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats, execStart time.Time, @@ -342,7 +343,7 @@ func (e *Executor) executePlan( } // rollbackExecIfNeeded rollbacks the partial execution if earlier it was detected that it needs partial query execution to be rolled back. -func (e *Executor) rollbackExecIfNeeded(ctx context.Context, safeSession *SafeSession, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats, err error) error { +func (e *Executor) rollbackExecIfNeeded(ctx context.Context, safeSession *econtext.SafeSession, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats, err error) error { if safeSession.InTransaction() && safeSession.IsRollbackSet() { rErr := e.rollbackPartialExec(ctx, safeSession, bindVars, logStats) return vterrors.Wrap(err, rErr.Error()) @@ -353,7 +354,7 @@ func (e *Executor) rollbackExecIfNeeded(ctx context.Context, safeSession *SafeSe // rollbackPartialExec rollbacks to the savepoint or rollbacks transaction based on the value set on SafeSession.rollbackOnPartialExec. // Once, it is used the variable is reset. // If it fails to rollback to the previous savepoint then, the transaction is forced to be rolled back. -func (e *Executor) rollbackPartialExec(ctx context.Context, safeSession *SafeSession, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) error { +func (e *Executor) rollbackPartialExec(ctx context.Context, safeSession *econtext.SafeSession, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) error { var err error var errMsg strings.Builder @@ -367,8 +368,8 @@ func (e *Executor) rollbackPartialExec(ctx context.Context, safeSession *SafeSes } // needs to rollback only once. - rQuery := safeSession.rollbackOnPartialExec - if rQuery != txRollback { + rQuery := safeSession.GetRollbackOnPartialExec() + if rQuery != econtext.TxRollback { safeSession.SavepointRollback() _, _, err = e.execute(ctx, nil, safeSession, rQuery, bindVars, logStats) // If no error, the revert is successful with the savepoint. Notify the reason as error to the client. @@ -388,9 +389,9 @@ func (e *Executor) rollbackPartialExec(ctx context.Context, safeSession *SafeSes return vterrors.New(vtrpcpb.Code_ABORTED, errMsg.String()) } -func (e *Executor) setLogStats(logStats *logstats.LogStats, plan *engine.Plan, vcursor *vcursorImpl, execStart time.Time, err error, qr *sqltypes.Result) { +func (e *Executor) setLogStats(logStats *logstats.LogStats, plan *engine.Plan, vcursor *econtext.VCursorImpl, execStart time.Time, err error, qr *sqltypes.Result) { logStats.StmtType = plan.Type.String() - logStats.ActiveKeyspace = vcursor.keyspace + logStats.ActiveKeyspace = vcursor.GetKeyspace() logStats.TablesUsed = plan.TablesUsed logStats.TabletType = vcursor.TabletType().String() errCount := e.logExecutionEnd(logStats, execStart, plan, err, qr) diff --git a/go/vt/vtgate/planbuilder/collations_test.go b/go/vt/vtgate/planbuilder/collations_test.go index b393e186679..0595039e673 100644 --- a/go/vt/vtgate/planbuilder/collations_test.go +++ b/go/vt/vtgate/planbuilder/collations_test.go @@ -41,15 +41,13 @@ type collationTestCase struct { } func (tc *collationTestCase) run(t *testing.T) { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(t, "vschemas/schema.json", false), - SysVarEnabled: true, - Version: Gen4, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(t, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(t, err) - tc.addCollationsToSchema(vschemaWrapper) - plan, err := TestBuilder(tc.query, vschemaWrapper, vschemaWrapper.CurrentDb()) + tc.addCollationsToSchema(vw) + plan, err := TestBuilder(tc.query, vw, vw.CurrentDb()) require.NoError(t, err) tc.check(t, tc.collations, plan.Instructions) } diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index acba2caf937..ccbc9821170 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -74,17 +74,16 @@ func TestPlanTestSuite(t *testing.T) { func (s *planTestSuite) TestPlan() { defer utils.EnsureNoLeaks(s.T()) - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - TabletType_: topodatapb.TabletType_PRIMARY, - SysVarEnabled: true, - TestBuilder: TestBuilder, - Env: vtenv.NewTestEnv(), - } - s.addPKs(vschemaWrapper.V, "user", []string{"user", "music"}) - s.addPKsProvided(vschemaWrapper.V, "user", []string{"user_extra"}, []string{"id", "user_id"}) - s.addPKsProvided(vschemaWrapper.V, "ordering", []string{"order"}, []string{"oid", "region_id"}) - s.addPKsProvided(vschemaWrapper.V, "ordering", []string{"order_event"}, []string{"oid", "ename"}) + + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + s.addPKs(vschema, "user", []string{"user", "music"}) + s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"}) + s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"}) + s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"}) // You will notice that some tests expect user.Id instead of user.id. // This is because we now pre-create vindex columns in the symbol @@ -92,77 +91,73 @@ func (s *planTestSuite) TestPlan() { // the column is named as Id. This is to make sure that // column names are case-preserved, but treated as // case-insensitive even if they come from the vschema. - s.testFile("aggr_cases.json", vschemaWrapper, false) - s.testFile("dml_cases.json", vschemaWrapper, false) - s.testFile("from_cases.json", vschemaWrapper, false) - s.testFile("filter_cases.json", vschemaWrapper, false) - s.testFile("postprocess_cases.json", vschemaWrapper, false) - s.testFile("select_cases.json", vschemaWrapper, false) - s.testFile("symtab_cases.json", vschemaWrapper, false) - s.testFile("unsupported_cases.json", vschemaWrapper, false) - s.testFile("unknown_schema_cases.json", vschemaWrapper, false) - s.testFile("vindex_func_cases.json", vschemaWrapper, false) - s.testFile("wireup_cases.json", vschemaWrapper, false) - s.testFile("memory_sort_cases.json", vschemaWrapper, false) - s.testFile("use_cases.json", vschemaWrapper, false) - s.testFile("set_cases.json", vschemaWrapper, false) - s.testFile("union_cases.json", vschemaWrapper, false) - s.testFile("large_union_cases.json", vschemaWrapper, false) - s.testFile("transaction_cases.json", vschemaWrapper, false) - s.testFile("lock_cases.json", vschemaWrapper, false) - s.testFile("large_cases.json", vschemaWrapper, false) - s.testFile("ddl_cases_no_default_keyspace.json", vschemaWrapper, false) - s.testFile("flush_cases_no_default_keyspace.json", vschemaWrapper, false) - s.testFile("show_cases_no_default_keyspace.json", vschemaWrapper, false) - s.testFile("stream_cases.json", vschemaWrapper, false) - s.testFile("info_schema80_cases.json", vschemaWrapper, false) - s.testFile("reference_cases.json", vschemaWrapper, false) - s.testFile("vexplain_cases.json", vschemaWrapper, false) - s.testFile("misc_cases.json", vschemaWrapper, false) - s.testFile("cte_cases.json", vschemaWrapper, false) + s.testFile("aggr_cases.json", vw, false) + s.testFile("dml_cases.json", vw, false) + s.testFile("from_cases.json", vw, false) + s.testFile("filter_cases.json", vw, false) + s.testFile("postprocess_cases.json", vw, false) + s.testFile("select_cases.json", vw, false) + s.testFile("symtab_cases.json", vw, false) + s.testFile("unsupported_cases.json", vw, false) + s.testFile("unknown_schema_cases.json", vw, false) + s.testFile("vindex_func_cases.json", vw, false) + s.testFile("wireup_cases.json", vw, false) + s.testFile("memory_sort_cases.json", vw, false) + s.testFile("use_cases.json", vw, false) + s.testFile("set_cases.json", vw, false) + s.testFile("union_cases.json", vw, false) + s.testFile("large_union_cases.json", vw, false) + s.testFile("transaction_cases.json", vw, false) + s.testFile("lock_cases.json", vw, false) + s.testFile("large_cases.json", vw, false) + s.testFile("ddl_cases_no_default_keyspace.json", vw, false) + s.testFile("flush_cases_no_default_keyspace.json", vw, false) + s.testFile("show_cases_no_default_keyspace.json", vw, false) + s.testFile("stream_cases.json", vw, false) + s.testFile("info_schema80_cases.json", vw, false) + s.testFile("reference_cases.json", vw, false) + s.testFile("vexplain_cases.json", vw, false) + s.testFile("misc_cases.json", vw, false) + s.testFile("cte_cases.json", vw, false) } // TestForeignKeyPlanning tests the planning of foreign keys in a managed mode by Vitess. func (s *planTestSuite) TestForeignKeyPlanning() { + env := vtenv.NewTestEnv() vschema := loadSchema(s.T(), "vschemas/schema.json", true) - s.setFks(vschema) - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: vschema, - TestBuilder: TestBuilder, - Env: vtenv.NewTestEnv(), - } + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("foreignkey_cases.json", vschemaWrapper, false) + s.setFks(vschema) + s.testFile("foreignkey_cases.json", vw, false) } // TestForeignKeyChecksOn tests the planning when the session variable for foreign_key_checks is set to ON. func (s *planTestSuite) TestForeignKeyChecksOn() { + env := vtenv.NewTestEnv() vschema := loadSchema(s.T(), "vschemas/schema.json", true) - s.setFks(vschema) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + fkChecksState := true - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: vschema, - TestBuilder: TestBuilder, - ForeignKeyChecksState: &fkChecksState, - Env: vtenv.NewTestEnv(), - } + vw.ForeignKeyChecksState = &fkChecksState - s.testFile("foreignkey_checks_on_cases.json", vschemaWrapper, false) + s.setFks(vschema) + s.testFile("foreignkey_checks_on_cases.json", vw, false) } // TestForeignKeyChecksOff tests the planning when the session variable for foreign_key_checks is set to OFF. func (s *planTestSuite) TestForeignKeyChecksOff() { + env := vtenv.NewTestEnv() vschema := loadSchema(s.T(), "vschemas/schema.json", true) - s.setFks(vschema) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + fkChecksState := false - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: vschema, - TestBuilder: TestBuilder, - ForeignKeyChecksState: &fkChecksState, - Env: vtenv.NewTestEnv(), - } + vw.ForeignKeyChecksState = &fkChecksState - s.testFile("foreignkey_checks_off_cases.json", vschemaWrapper, false) + s.setFks(vschema) + s.testFile("foreignkey_checks_off_cases.json", vw, false) } func (s *planTestSuite) setFks(vschema *vindexes.VSchema) { @@ -266,120 +261,127 @@ func (s *planTestSuite) TestSystemTables57() { MySQLServerVersion: "5.7.9", }) require.NoError(s.T(), err) - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Env: env, - } - s.testFile("info_schema57_cases.json", vschemaWrapper, false) + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + s.testFile("info_schema57_cases.json", vw, false) } func (s *planTestSuite) TestSysVarSetDisabled() { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - SysVarEnabled: false, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("set_sysvar_disabled_cases.json", vschemaWrapper, false) + vw.SysVarEnabled = false + + s.testFile("set_sysvar_disabled_cases.json", vw, false) } func (s *planTestSuite) TestViews() { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - EnableViews: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("view_cases.json", vschemaWrapper, false) + vw.EnableViews = true + + s.testFile("view_cases.json", vw, false) } func (s *planTestSuite) TestOne() { reset := operators.EnableDebugPrinting() defer reset() - lv := loadSchema(s.T(), "vschemas/schema.json", true) - s.setFks(lv) - s.addPKs(lv, "user", []string{"user", "music"}) - s.addPKs(lv, "main", []string{"unsharded"}) - s.addPKsProvided(lv, "user", []string{"user_extra"}, []string{"id", "user_id"}) - s.addPKsProvided(lv, "ordering", []string{"order"}, []string{"oid", "region_id"}) - s.addPKsProvided(lv, "ordering", []string{"order_event"}, []string{"oid", "ename"}) - vschema := &vschemawrapper.VSchemaWrapper{ - V: lv, - TestBuilder: TestBuilder, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + s.setFks(vschema) + s.addPKs(vschema, "user", []string{"user", "music"}) + s.addPKs(vschema, "main", []string{"unsharded"}) + s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"}) + s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"}) + s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"}) - s.testFile("onecase.json", vschema, false) + s.testFile("onecase.json", vw, false) } func (s *planTestSuite) TestOneTPCC() { reset := operators.EnableDebugPrinting() defer reset() - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/tpcc_schema.json", true), - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/tpcc_schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("onecase.json", vschema, false) + s.testFile("onecase.json", vw, false) } func (s *planTestSuite) TestOneWithMainAsDefault() { reset := operators.EnableDebugPrinting() defer reset() - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "main", - Sharded: false, - }, - Env: vtenv.NewTestEnv(), - } - s.testFile("onecase.json", vschema, false) + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + vw.Vcursor.SetTarget("main") + vw.Keyspace = &vindexes.Keyspace{Name: "main"} + + s.testFile("onecase.json", vw, false) } func (s *planTestSuite) TestOneWithSecondUserAsDefault() { reset := operators.EnableDebugPrinting() defer reset() - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "second_user", - Sharded: true, - }, - Env: vtenv.NewTestEnv(), + + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + vw.Vcursor.SetTarget("second_user") + vw.Keyspace = &vindexes.Keyspace{ + Name: "second_user", + Sharded: true, } - s.testFile("onecase.json", vschema, false) + s.testFile("onecase.json", vw, false) } func (s *planTestSuite) TestOneWithUserAsDefault() { reset := operators.EnableDebugPrinting() defer reset() - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "user", - Sharded: true, - }, - Env: vtenv.NewTestEnv(), + + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + vw.Vcursor.SetTarget("user") + vw.Keyspace = &vindexes.Keyspace{ + Name: "user", + Sharded: true, } - s.testFile("onecase.json", vschema, false) + s.testFile("onecase.json", vw, false) } func (s *planTestSuite) TestOneWithTPCHVSchema() { reset := operators.EnableDebugPrinting() defer reset() - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/tpch_schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } - s.testFile("onecase.json", vschema, false) + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + s.testFile("onecase.json", vw, false) } func (s *planTestSuite) TestOneWith57Version() { @@ -390,52 +392,47 @@ func (s *planTestSuite) TestOneWith57Version() { MySQLServerVersion: "5.7.9", }) require.NoError(s.T(), err) - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Env: env, - } + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("onecase.json", vschema, false) + s.testFile("onecase.json", vw, false) } func (s *planTestSuite) TestRubyOnRailsQueries() { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/rails_schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/rails_schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("rails_cases.json", vschemaWrapper, false) + s.testFile("rails_cases.json", vw, false) } func (s *planTestSuite) TestOLTP() { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/oltp_schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/oltp_schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("oltp_cases.json", vschemaWrapper, false) + s.testFile("oltp_cases.json", vw, false) } func (s *planTestSuite) TestTPCC() { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/tpcc_schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/tpcc_schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("tpcc_cases.json", vschemaWrapper, false) + s.testFile("tpcc_cases.json", vw, false) } func (s *planTestSuite) TestTPCH() { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/tpch_schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/tpch_schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("tpch_cases.json", vschemaWrapper, false) + s.testFile("tpch_cases.json", vw, false) } func BenchmarkOLTP(b *testing.B) { @@ -451,15 +448,14 @@ func BenchmarkTPCH(b *testing.B) { } func benchmarkWorkload(b *testing.B, name string) { - vschemaWrapper := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(b, "vschemas/"+name+"_schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(b, "vschemas/"+name+"_schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(b, err) testCases := readJSONTests(name + "_cases.json") b.ResetTimer() - benchmarkPlanner(b, Gen4, testCases, vschemaWrapper) + benchmarkPlanner(b, Gen4, testCases, vw) } func (s *planTestSuite) TestBypassPlanningShardTargetFromFile() { @@ -478,35 +474,33 @@ func (s *planTestSuite) TestBypassPlanningShardTargetFromFile() { } func (s *planTestSuite) TestBypassPlanningKeyrangeTargetFromFile() { + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + keyRange, _ := key.ParseShardingSpec("-") + vw.Dest = key.DestinationExactKeyRange{KeyRange: keyRange[0]} - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "main", - Sharded: false, - }, - TabletType_: topodatapb.TabletType_PRIMARY, - Dest: key.DestinationExactKeyRange{KeyRange: keyRange[0]}, - Env: vtenv.NewTestEnv(), - } + vw.Vcursor.SetTarget("main") + vw.Keyspace = &vindexes.Keyspace{Name: "main"} - s.testFile("bypass_keyrange_cases.json", vschema, false) + s.testFile("bypass_keyrange_cases.json", vw, false) } func (s *planTestSuite) TestWithDefaultKeyspaceFromFile() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // We are testing this separately so we can set a default keyspace - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "main", - Sharded: false, - }, - TabletType_: topodatapb.TabletType_PRIMARY, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + vw.Vcursor.SetTarget("main") + vw.Keyspace = &vindexes.Keyspace{Name: "main"} + ts := memorytopo.NewServer(ctx, "cell1") ts.CreateKeyspace(ctx, "main", &topodatapb.Keyspace{}) ts.CreateKeyspace(ctx, "user", &topodatapb.Keyspace{}) @@ -521,97 +515,92 @@ func (s *planTestSuite) TestWithDefaultKeyspaceFromFile() { }) require.True(s.T(), created) - s.testFile("alterVschema_cases.json", vschema, false) - s.testFile("ddl_cases.json", vschema, false) - s.testFile("migration_cases.json", vschema, false) - s.testFile("flush_cases.json", vschema, false) - s.testFile("show_cases.json", vschema, false) - s.testFile("call_cases.json", vschema, false) + s.testFile("alterVschema_cases.json", vw, false) + s.testFile("ddl_cases.json", vw, false) + s.testFile("migration_cases.json", vw, false) + s.testFile("flush_cases.json", vw, false) + s.testFile("show_cases.json", vw, false) + s.testFile("call_cases.json", vw, false) } func (s *planTestSuite) TestWithDefaultKeyspaceFromFileSharded() { // We are testing this separately so we can set a default keyspace - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "second_user", - Sharded: true, - }, - TabletType_: topodatapb.TabletType_PRIMARY, - Env: vtenv.NewTestEnv(), + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + vw.Vcursor.SetTarget("second_user") + vw.Keyspace = &vindexes.Keyspace{ + Name: "second_user", + Sharded: true, } - s.testFile("select_cases_with_default.json", vschema, false) + s.testFile("select_cases_with_default.json", vw, false) } func (s *planTestSuite) TestWithUserDefaultKeyspaceFromFileSharded() { // We are testing this separately so we can set a default keyspace - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "user", - Sharded: true, - }, - TabletType_: topodatapb.TabletType_PRIMARY, - Env: vtenv.NewTestEnv(), + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + vw.Vcursor.SetTarget("user") + vw.Keyspace = &vindexes.Keyspace{ + Name: "user", + Sharded: true, } - s.testFile("select_cases_with_user_as_default.json", vschema, false) - s.testFile("dml_cases_with_user_as_default.json", vschema, false) + s.testFile("select_cases_with_user_as_default.json", vw, false) + s.testFile("dml_cases_with_user_as_default.json", vw, false) } func (s *planTestSuite) TestWithSystemSchemaAsDefaultKeyspace() { // We are testing this separately so we can set a default keyspace - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{Name: "information_schema"}, - TabletType_: topodatapb.TabletType_PRIMARY, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + vw.Keyspace = &vindexes.Keyspace{Name: "information_schema"} - s.testFile("sysschema_default.json", vschema, false) + s.testFile("sysschema_default.json", vw, false) } func (s *planTestSuite) TestOtherPlanningFromFile() { // We are testing this separately so we can set a default keyspace - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/schema.json", true), - Keyspace: &vindexes.Keyspace{ - Name: "main", - Sharded: false, - }, - TabletType_: topodatapb.TabletType_PRIMARY, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("other_read_cases.json", vschema, false) - s.testFile("other_admin_cases.json", vschema, false) + vw.Vcursor.SetTarget("main") + vw.Keyspace = &vindexes.Keyspace{Name: "main"} + + s.testFile("other_read_cases.json", vw, false) + s.testFile("other_admin_cases.json", vw, false) } func (s *planTestSuite) TestMirrorPlanning() { - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/mirror_schema.json", true), - TabletType_: topodatapb.TabletType_PRIMARY, - SysVarEnabled: true, - TestBuilder: TestBuilder, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/mirror_schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) - s.testFile("mirror_cases.json", vschema, false) + s.testFile("mirror_cases.json", vw, false) } func (s *planTestSuite) TestOneMirror() { reset := operators.EnableDebugPrinting() defer reset() - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(s.T(), "vschemas/mirror_schema.json", true), - TabletType_: topodatapb.TabletType_PRIMARY, - SysVarEnabled: true, - TestBuilder: TestBuilder, - Env: vtenv.NewTestEnv(), - } - s.testFile("onecase.json", vschema, false) + env := vtenv.NewTestEnv() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(s.T(), err) + + s.testFile("onecase.json", vw, false) } func loadSchema(t testing.TB, filename string, setCollation bool) *vindexes.VSchema { @@ -784,30 +773,29 @@ func locateFile(name string) string { var benchMarkFiles = []string{"from_cases.json", "filter_cases.json", "large_cases.json", "aggr_cases.json", "select_cases.json", "union_cases.json"} func BenchmarkPlanner(b *testing.B) { - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(b, "vschemas/schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(b, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(b, err) + for _, filename := range benchMarkFiles { testCases := readJSONTests(filename) b.Run(filename+"-gen4", func(b *testing.B) { - benchmarkPlanner(b, Gen4, testCases, vschema) + benchmarkPlanner(b, Gen4, testCases, vw) }) } } func BenchmarkSemAnalysis(b *testing.B) { - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(b, "vschemas/schema.json", true), - SysVarEnabled: true, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(b, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(b, err) for i := 0; i < b.N; i++ { for _, filename := range benchMarkFiles { for _, tc := range readJSONTests(filename) { - exerciseAnalyzer(tc.Query, vschema.CurrentDb(), vschema) + exerciseAnalyzer(tc.Query, vw.CurrentDb(), vw) } } } @@ -832,12 +820,10 @@ func exerciseAnalyzer(query, database string, s semantics.SchemaInformation) { } func BenchmarkSelectVsDML(b *testing.B) { - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(b, "vschemas/schema.json", true), - SysVarEnabled: true, - Version: Gen4, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(b, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(b, err) dmlCases := readJSONTests("dml_cases.json") selectCases := readJSONTests("select_cases.json") @@ -851,40 +837,33 @@ func BenchmarkSelectVsDML(b *testing.B) { }) b.Run("DML (random sample, N=32)", func(b *testing.B) { - benchmarkPlanner(b, Gen4, dmlCases[:32], vschema) + benchmarkPlanner(b, Gen4, dmlCases[:32], vw) }) b.Run("Select (random sample, N=32)", func(b *testing.B) { - benchmarkPlanner(b, Gen4, selectCases[:32], vschema) + benchmarkPlanner(b, Gen4, selectCases[:32], vw) }) } func BenchmarkBaselineVsMirrored(b *testing.B) { + env := vtenv.NewTestEnv() baseline := loadSchema(b, "vschemas/mirror_schema.json", true) baseline.MirrorRules = map[string]*vindexes.MirrorRule{} - baselineVschema := &vschemawrapper.VSchemaWrapper{ - V: baseline, - SysVarEnabled: true, - Version: Gen4, - Env: vtenv.NewTestEnv(), - } + bvw, err := vschemawrapper.NewVschemaWrapper(env, baseline, TestBuilder) + require.NoError(b, err) mirroredSchema := loadSchema(b, "vschemas/mirror_schema.json", true) - mirroredVschema := &vschemawrapper.VSchemaWrapper{ - V: mirroredSchema, - SysVarEnabled: true, - Version: Gen4, - Env: vtenv.NewTestEnv(), - } + mvw, err := vschemawrapper.NewVschemaWrapper(env, mirroredSchema, TestBuilder) + require.NoError(b, err) cases := readJSONTests("mirror_cases.json") b.Run("Baseline", func(b *testing.B) { - benchmarkPlanner(b, Gen4, cases, baselineVschema) + benchmarkPlanner(b, Gen4, cases, bvw) }) b.Run("Mirrored", func(b *testing.B) { - benchmarkPlanner(b, Gen4, cases, mirroredVschema) + benchmarkPlanner(b, Gen4, cases, mvw) }) } diff --git a/go/vt/vtgate/planbuilder/show_test.go b/go/vt/vtgate/planbuilder/show_test.go index bfdb9a623a0..c3651aaa1cd 100644 --- a/go/vt/vtgate/planbuilder/show_test.go +++ b/go/vt/vtgate/planbuilder/show_test.go @@ -32,10 +32,13 @@ import ( ) func TestBuildDBPlan(t *testing.T) { - vschema := &vschemawrapper.VSchemaWrapper{ - Keyspace: &vindexes.Keyspace{Name: "main"}, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(t, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(t, err) + + vw.Vcursor.SetTarget("main") + vw.Keyspace = &vindexes.Keyspace{Name: "main"} testCases := []struct { query string @@ -54,7 +57,7 @@ func TestBuildDBPlan(t *testing.T) { require.NoError(t, err) show := parserOut.(*sqlparser.Show) - primitive, err := buildDBPlan(show.Internal.(*sqlparser.ShowBasic), vschema) + primitive, err := buildDBPlan(show.Internal.(*sqlparser.ShowBasic), vw) require.NoError(t, err) result, err := primitive.TryExecute(context.Background(), nil, nil, false) diff --git a/go/vt/vtgate/planbuilder/simplifier_test.go b/go/vt/vtgate/planbuilder/simplifier_test.go index c4b9fd71174..dce21b3e175 100644 --- a/go/vt/vtgate/planbuilder/simplifier_test.go +++ b/go/vt/vtgate/planbuilder/simplifier_test.go @@ -38,21 +38,21 @@ func TestSimplifyBuggyQuery(t *testing.T) { query := "select distinct count(distinct a), count(distinct 4) from user left join unsharded on 0 limit 5" // select 0 from unsharded union select 0 from `user` union select 0 from unsharded // select 0 from unsharded union (select 0 from `user` union select 0 from unsharded) - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(t, "vschemas/schema.json", true), - Version: Gen4, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(t, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(t, err) + stmt, reserved, err := sqlparser.NewTestParser().Parse2(query) require.NoError(t, err) - rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) + rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) reservedVars := sqlparser.NewReservedVars("vtg", reserved) simplified := simplifier.SimplifyStatement( stmt.(sqlparser.SelectStatement), - vschema.CurrentDb(), - vschema, - keepSameError(query, reservedVars, vschema, rewritten.BindVarNeeds), + vw.CurrentDb(), + vw, + keepSameError(query, reservedVars, vw, rewritten.BindVarNeeds), ) fmt.Println(sqlparser.String(simplified)) @@ -61,21 +61,22 @@ func TestSimplifyBuggyQuery(t *testing.T) { func TestSimplifyPanic(t *testing.T) { t.Skip("not needed to run") query := "(select id from unsharded union select id from unsharded_auto) union (select id from unsharded_auto union select name from unsharded)" - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(t, "vschemas/schema.json", true), - Version: Gen4, - Env: vtenv.NewTestEnv(), - } + + env := vtenv.NewTestEnv() + vschema := loadSchema(t, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(t, err) + stmt, reserved, err := sqlparser.NewTestParser().Parse2(query) require.NoError(t, err) - rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) + rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) reservedVars := sqlparser.NewReservedVars("vtg", reserved) simplified := simplifier.SimplifyStatement( stmt.(sqlparser.SelectStatement), - vschema.CurrentDb(), - vschema, - keepPanicking(query, reservedVars, vschema, rewritten.BindVarNeeds), + vw.CurrentDb(), + vw, + keepPanicking(query, reservedVars, vw, rewritten.BindVarNeeds), ) fmt.Println(sqlparser.String(simplified)) @@ -83,11 +84,11 @@ func TestSimplifyPanic(t *testing.T) { func TestUnsupportedFile(t *testing.T) { t.Skip("run manually to see if any queries can be simplified") - vschema := &vschemawrapper.VSchemaWrapper{ - V: loadSchema(t, "vschemas/schema.json", true), - Version: Gen4, - Env: vtenv.NewTestEnv(), - } + env := vtenv.NewTestEnv() + vschema := loadSchema(t, "vschemas/schema.json", true) + vw, err := vschemawrapper.NewVschemaWrapper(env, vschema, TestBuilder) + require.NoError(t, err) + fmt.Println(vschema) for _, tcase := range readJSONTests("unsupported_cases.txt") { t.Run(tcase.Query, func(t *testing.T) { @@ -99,11 +100,10 @@ func TestUnsupportedFile(t *testing.T) { t.Skip() return } - rewritten, err := sqlparser.RewriteAST(stmt, vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) + rewritten, err := sqlparser.RewriteAST(stmt, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) if err != nil { t.Skip() } - vschema.CurrentDb() reservedVars := sqlparser.NewReservedVars("vtg", reserved) ast := rewritten.AST @@ -111,9 +111,9 @@ func TestUnsupportedFile(t *testing.T) { stmt, _, _ = sqlparser.NewTestParser().Parse2(tcase.Query) simplified := simplifier.SimplifyStatement( stmt.(sqlparser.SelectStatement), - vschema.CurrentDb(), - vschema, - keepSameError(tcase.Query, reservedVars, vschema, rewritten.BindVarNeeds), + vw.CurrentDb(), + vw, + keepSameError(tcase.Query, reservedVars, vw, rewritten.BindVarNeeds), ) if simplified == nil { diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index ea0ae7b0e83..6e2cf9ad8ba 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -41,6 +41,7 @@ import ( "vitess.io/vitess/go/vt/topo/topoproto" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vttablet/queryservice" ) @@ -71,13 +72,10 @@ type shardActionFunc func(rs *srvtopo.ResolvedShard, i int) error type shardActionTransactionFunc func(rs *srvtopo.ResolvedShard, i int, shardActionInfo *shardActionInfo) (*shardActionInfo, error) type ( - resultsObserver interface { - observe(*sqltypes.Result) - } nullResultsObserver struct{} ) -func (nullResultsObserver) observe(*sqltypes.Result) {} +func (nullResultsObserver) Observe(*sqltypes.Result) {} // NewScatterConn creates a new ScatterConn. func NewScatterConn(statsName string, txConn *TxConn, gw *TabletGateway) *ScatterConn { @@ -106,7 +104,7 @@ func (stc *ScatterConn) startAction(name string, target *querypb.Target) (time.T return startTime, statsKey } -func (stc *ScatterConn) endAction(startTime time.Time, allErrors *concurrency.AllErrorRecorder, statsKey []string, err *error, session *SafeSession) { +func (stc *ScatterConn) endAction(startTime time.Time, allErrors *concurrency.AllErrorRecorder, statsKey []string, err *error, session *econtext.SafeSession) { if *err != nil { allErrors.RecordError(*err) // Don't increment the error counter for duplicate @@ -150,10 +148,10 @@ func (stc *ScatterConn) ExecuteMultiShard( primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, - session *SafeSession, + session *econtext.SafeSession, autocommit bool, ignoreMaxMemoryRows bool, - resultsObserver resultsObserver, + resultsObserver econtext.ResultsObserver, ) (qr *sqltypes.Result, errs []error) { if len(rss) != len(queries) { @@ -164,7 +162,7 @@ func (stc *ScatterConn) ExecuteMultiShard( var mu sync.Mutex qr = new(sqltypes.Result) - if session.InLockSession() && session.TriggerLockHeartBeat() { + if session.InLockSession() && triggerLockHeartBeat(session) { go stc.runLockQuery(ctx, session) } @@ -260,7 +258,7 @@ func (stc *ScatterConn) ExecuteMultiShard( default: return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unexpected actionNeeded on query execution: %v", info.actionNeeded) } - session.logging.log(primitive, rs.Target, rs.Gateway, queries[i].Sql, info.actionNeeded == begin || info.actionNeeded == reserveBegin, queries[i].BindVariables) + session.Log(primitive, rs.Target, rs.Gateway, queries[i].Sql, info.actionNeeded == begin || info.actionNeeded == reserveBegin, queries[i].BindVariables) // We need to new shard info irrespective of the error. newInfo := info.updateTransactionAndReservedID(transactionID, reservedID, alias, innerqr) @@ -271,7 +269,7 @@ func (stc *ScatterConn) ExecuteMultiShard( defer mu.Unlock() if innerqr != nil { - resultsObserver.observe(innerqr) + resultsObserver.Observe(innerqr) } // Don't append more rows if row count is exceeded. @@ -289,7 +287,13 @@ func (stc *ScatterConn) ExecuteMultiShard( return qr, allErrors.GetErrors() } -func (stc *ScatterConn) runLockQuery(ctx context.Context, session *SafeSession) { +func triggerLockHeartBeat(session *econtext.SafeSession) bool { + now := time.Now().Unix() + lastHeartbeat := session.GetLockHeartbeat() + return now-lastHeartbeat >= int64(lockHeartbeatTime.Seconds()) +} + +func (stc *ScatterConn) runLockQuery(ctx context.Context, session *econtext.SafeSession) { rs := &srvtopo.ResolvedShard{Target: session.LockSession.Target, Gateway: stc.gateway} query := &querypb.BoundQuery{Sql: "select 1", BindVariables: nil} _, lockErr := stc.ExecuteLock(ctx, rs, query, session, sqlparser.IsUsedLock) @@ -298,7 +302,7 @@ func (stc *ScatterConn) runLockQuery(ctx context.Context, session *SafeSession) } } -func checkAndResetShardSession(info *shardActionInfo, err error, session *SafeSession, target *querypb.Target) reset { +func checkAndResetShardSession(info *shardActionInfo, err error, session *econtext.SafeSession, target *querypb.Target) reset { retry := none if info.reservedID != 0 && info.transactionID == 0 { if wasConnectionClosed(err) { @@ -314,7 +318,7 @@ func checkAndResetShardSession(info *shardActionInfo, err error, session *SafeSe return retry } -func getQueryService(ctx context.Context, rs *srvtopo.ResolvedShard, info *shardActionInfo, session *SafeSession, skipReset bool) (queryservice.QueryService, error) { +func getQueryService(ctx context.Context, rs *srvtopo.ResolvedShard, info *shardActionInfo, session *econtext.SafeSession, skipReset bool) (queryservice.QueryService, error) { if info.alias == nil { return rs.Gateway, nil } @@ -365,18 +369,18 @@ func (stc *ScatterConn) StreamExecuteMulti( query string, rss []*srvtopo.ResolvedShard, bindVars []map[string]*querypb.BindVariable, - session *SafeSession, + session *econtext.SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, - resultsObserver resultsObserver, + resultsObserver econtext.ResultsObserver, ) []error { - if session.InLockSession() && session.TriggerLockHeartBeat() { + if session.InLockSession() && triggerLockHeartBeat(session) { go stc.runLockQuery(ctx, session) } observedCallback := func(reply *sqltypes.Result) error { if reply != nil { - resultsObserver.observe(reply) + resultsObserver.Observe(reply) } return callback(reply) } @@ -469,7 +473,7 @@ func (stc *ScatterConn) StreamExecuteMulti( default: return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unexpected actionNeeded on query execution: %v", info.actionNeeded) } - session.logging.log(primitive, rs.Target, rs.Gateway, query, info.actionNeeded == begin || info.actionNeeded == reserveBegin, bindVars[i]) + session.Log(primitive, rs.Target, rs.Gateway, query, info.actionNeeded == begin || info.actionNeeded == reserveBegin, bindVars[i]) // We need the new shard info irrespective of the error. newInfo := info.updateTransactionAndReservedID(transactionID, reservedID, alias, nil) @@ -604,7 +608,7 @@ func (stc *ScatterConn) multiGo( startTime, statsKey := stc.startAction(name, rs.Target) // Send a dummy session. // TODO(sougou): plumb a real session through this call. - defer stc.endAction(startTime, allErrors, statsKey, &err, NewSafeSession(nil)) + defer stc.endAction(startTime, allErrors, statsKey, &err, econtext.NewSafeSession(nil)) err = action(rs, i) } @@ -646,7 +650,7 @@ func (stc *ScatterConn) multiGoTransaction( ctx context.Context, name string, rss []*srvtopo.ResolvedShard, - session *SafeSession, + session *econtext.SafeSession, autocommit bool, action shardActionTransactionFunc, ) (allErrors *concurrency.AllErrorRecorder) { @@ -730,7 +734,7 @@ func (stc *ScatterConn) multiGoTransaction( // It returns an error recorder in which each shard error is recorded positionally, // i.e. if rss[2] had an error, then the error recorder will store that error // in the second position. -func (stc *ScatterConn) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) { +func (stc *ScatterConn) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *econtext.SafeSession, lockFuncType sqlparser.LockingFuncType) (*sqltypes.Result, error) { var ( qr *sqltypes.Result @@ -833,7 +837,7 @@ func requireNewQS(err error, target *querypb.Target) bool { } // actionInfo looks at the current session, and returns information about what needs to be done for this tablet -func actionInfo(ctx context.Context, target *querypb.Target, session *SafeSession, autocommit bool, txMode vtgatepb.TransactionMode) (*shardActionInfo, *vtgatepb.Session_ShardSession, error) { +func actionInfo(ctx context.Context, target *querypb.Target, session *econtext.SafeSession, autocommit bool, txMode vtgatepb.TransactionMode) (*shardActionInfo, *vtgatepb.Session_ShardSession, error) { if !(session.InTransaction() || session.InReservedConn()) { return &shardActionInfo{}, nil, nil } @@ -876,7 +880,7 @@ func actionInfo(ctx context.Context, target *querypb.Target, session *SafeSessio } // lockInfo looks at the current session, and returns information about what needs to be done for this tablet -func lockInfo(target *querypb.Target, session *SafeSession, lockFuncType sqlparser.LockingFuncType) (*shardActionInfo, error) { +func lockInfo(target *querypb.Target, session *econtext.SafeSession, lockFuncType sqlparser.LockingFuncType) (*shardActionInfo, error) { info := &shardActionInfo{actionNeeded: nothing} if session.LockSession != nil { if !proto.Equal(target, session.LockSession.Target) { @@ -908,6 +912,22 @@ type shardActionInfo struct { rowsAffected bool } +func (sai *shardActionInfo) TransactionID() int64 { + return sai.transactionID +} + +func (sai *shardActionInfo) ReservedID() int64 { + return sai.reservedID +} + +func (sai *shardActionInfo) RowsAffected() bool { + return sai.rowsAffected +} + +func (sai *shardActionInfo) Alias() *topodatapb.TabletAlias { + return sai.alias +} + func (sai *shardActionInfo) updateTransactionAndReservedID(txID int64, rID int64, alias *topodatapb.TabletAlias, qr *sqltypes.Result) *shardActionInfo { firstTimeRowsAffected := false if txID != 0 && qr != nil && !sai.rowsAffected { diff --git a/go/vt/vtgate/scatter_conn_test.go b/go/vt/vtgate/scatter_conn_test.go index c5d4f350433..ab8680ca5e6 100644 --- a/go/vt/vtgate/scatter_conn_test.go +++ b/go/vt/vtgate/scatter_conn_test.go @@ -21,6 +21,7 @@ import ( "testing" "vitess.io/vitess/go/vt/log" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/mysql/sqlerror" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -100,7 +101,7 @@ func TestExecuteFailOnAutocommit(t *testing.T) { }, Autocommit: false, } - _, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}) + _, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, econtext.NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}) err := vterrors.Aggregate(errs) require.Error(t, err) require.Contains(t, err.Error(), "in autocommit mode, transactionID should be zero but was: 123") @@ -183,7 +184,7 @@ func TestExecutePanic(t *testing.T) { require.Contains(t, logMessage, "(*ScatterConn).multiGoTransaction") }() - _, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}) + _, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, econtext.NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}) } @@ -204,7 +205,7 @@ func TestReservedOnMultiReplica(t *testing.T) { res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa") - session := NewSafeSession(&vtgatepb.Session{InTransaction: false, InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: false, InReservedConn: true}) destinations := []key.Destination{key.DestinationShard("0")} for i := 0; i < 10; i++ { executeOnShards(t, ctx, res, keyspace, sc, session, destinations) @@ -351,7 +352,7 @@ func TestReservedBeginTableDriven(t *testing.T) { res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa") t.Run(test.name, func(t *testing.T) { - session := NewSafeSession(&vtgatepb.Session{}) + session := econtext.NewSafeSession(&vtgatepb.Session{}) for _, action := range test.actions { session.Session.InTransaction = action.transaction session.Session.InReservedConn = action.reserved @@ -384,7 +385,7 @@ func TestReservedConnFail(t *testing.T) { _ = hc.AddTestTablet("aa", "1", 1, keyspace, "1", topodatapb.TabletType_REPLICA, true, 1, nil) res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa") - session := NewSafeSession(&vtgatepb.Session{InTransaction: false, InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: false, InReservedConn: true}) destinations := []key.Destination{key.DestinationShard("0")} executeOnShards(t, ctx, res, keyspace, sc, session, destinations) diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go index d136542d176..124997bea9e 100644 --- a/go/vt/vtgate/tabletgateway_flaky_test.go +++ b/go/vt/vtgate/tabletgateway_flaky_test.go @@ -20,6 +20,8 @@ import ( "testing" "time" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" @@ -53,7 +55,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { TabletType: tabletType, } - ts := &fakeTopoServer{} + ts := &econtext.FakeTopoServer{} // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway @@ -156,7 +158,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) { TabletType: tabletType, } - ts := &fakeTopoServer{} + ts := &econtext.FakeTopoServer{} // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway @@ -286,7 +288,7 @@ func TestInconsistentStateDetectedBuffering(t *testing.T) { TabletType: tabletType, } - ts := &fakeTopoServer{} + ts := &econtext.FakeTopoServer{} // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go index 2aafb78af99..b318cb84981 100644 --- a/go/vt/vtgate/tabletgateway_test.go +++ b/go/vt/vtgate/tabletgateway_test.go @@ -22,6 +22,8 @@ import ( "strings" "testing" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -109,7 +111,7 @@ func TestTabletGatewayShuffleTablets(t *testing.T) { ctx := utils.LeakCheckContext(t) hc := discovery.NewFakeHealthCheck(nil) - ts := &fakeTopoServer{} + ts := &econtext.FakeTopoServer{} tg := NewTabletGateway(ctx, hc, ts, "local") defer tg.Close(ctx) @@ -183,7 +185,7 @@ func TestTabletGatewayReplicaTransactionError(t *testing.T) { TabletType: tabletType, } hc := discovery.NewFakeHealthCheck(nil) - ts := &fakeTopoServer{} + ts := &econtext.FakeTopoServer{} tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) @@ -218,7 +220,7 @@ func testTabletGatewayGenericHelper(t *testing.T, ctx context.Context, f func(ct TabletType: tabletType, } hc := discovery.NewFakeHealthCheck(nil) - ts := &fakeTopoServer{} + ts := &econtext.FakeTopoServer{} tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // no tablet @@ -306,7 +308,7 @@ func testTabletGatewayTransact(t *testing.T, ctx context.Context, f func(ctx con TabletType: tabletType, } hc := discovery.NewFakeHealthCheck(nil) - ts := &fakeTopoServer{} + ts := &econtext.FakeTopoServer{} tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) @@ -348,7 +350,7 @@ func verifyShardErrors(t *testing.T, err error, wantErrors []string, wantCode vt // TestWithRetry tests the functionality of withRetry function in different circumstances. func TestWithRetry(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - tg := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &fakeTopoServer{}, "cell") + tg := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &econtext.FakeTopoServer{}, "cell") tg.kev = discovery.NewKeyspaceEventWatcher(ctx, tg.srvTopoServer, tg.hc, tg.localCell) defer func() { cancel() diff --git a/go/vt/vtgate/tx_conn.go b/go/vt/vtgate/tx_conn.go index 315484ea499..3ce138bc0e4 100644 --- a/go/vt/vtgate/tx_conn.go +++ b/go/vt/vtgate/tx_conn.go @@ -33,6 +33,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vttablet/queryservice" ) @@ -80,7 +81,7 @@ var phaseMessage = map[commitPhase]string{ // Begin begins a new transaction. If one is already in progress, it commits it // and starts a new one. -func (txc *TxConn) Begin(ctx context.Context, session *SafeSession, txAccessModes []sqlparser.TxAccessMode) error { +func (txc *TxConn) Begin(ctx context.Context, session *econtext.SafeSession, txAccessModes []sqlparser.TxAccessMode) error { if session.InTransaction() { if err := txc.Commit(ctx, session); err != nil { return err @@ -102,7 +103,7 @@ func (txc *TxConn) Begin(ctx context.Context, session *SafeSession, txAccessMode // Commit commits the current transaction. The type of commit can be // best effort or 2pc depending on the session setting. -func (txc *TxConn) Commit(ctx context.Context, session *SafeSession) error { +func (txc *TxConn) Commit(ctx context.Context, session *econtext.SafeSession) error { defer session.ResetTx() if !session.InTransaction() { return nil @@ -123,7 +124,7 @@ func (txc *TxConn) Commit(ctx context.Context, session *SafeSession) error { return txc.commitNormal(ctx, session) } -func recordCommitTime(session *SafeSession, twopc bool, startTime time.Time) { +func recordCommitTime(session *econtext.SafeSession, twopc bool, startTime time.Time) { switch { case len(session.ShardSessions) == 0: // No-op @@ -143,7 +144,7 @@ func (txc *TxConn) queryService(ctx context.Context, alias *topodatapb.TabletAli return txc.tabletGateway.QueryServiceByAlias(ctx, alias, nil) } -func (txc *TxConn) commitShard(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error { +func (txc *TxConn) commitShard(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *econtext.ExecuteLogger) error { if s.TransactionId == 0 { return nil } @@ -159,19 +160,19 @@ func (txc *TxConn) commitShard(ctx context.Context, s *vtgatepb.Session_ShardSes } s.TransactionId = 0 s.ReservedId = reservedID - logging.log(nil, s.Target, nil, "commit", false, nil) + logging.Log(nil, s.Target, nil, "commit", false, nil) return nil } -func (txc *TxConn) commitNormal(ctx context.Context, session *SafeSession) error { - if err := txc.runSessions(ctx, session.PreSessions, session.logging, txc.commitShard); err != nil { +func (txc *TxConn) commitNormal(ctx context.Context, session *econtext.SafeSession) error { + if err := txc.runSessions(ctx, session.PreSessions, session.GetLogger(), txc.commitShard); err != nil { _ = txc.Release(ctx, session) return err } // Retain backward compatibility on commit order for the normal session. for i, shardSession := range session.ShardSessions { - if err := txc.commitShard(ctx, shardSession, session.logging); err != nil { + if err := txc.commitShard(ctx, shardSession, session.GetLogger()); err != nil { if i > 0 { nShards := i elipsis := false @@ -197,7 +198,7 @@ func (txc *TxConn) commitNormal(ctx context.Context, session *SafeSession) error } } - if err := txc.runSessions(ctx, session.PostSessions, session.logging, txc.commitShard); err != nil { + if err := txc.runSessions(ctx, session.PostSessions, session.GetLogger(), txc.commitShard); err != nil { // If last commit fails, there will be nothing to rollback. session.RecordWarning(&querypb.QueryWarning{Message: fmt.Sprintf("post-operation transaction had an error: %v", err)}) // With reserved connection we should release them. @@ -209,7 +210,7 @@ func (txc *TxConn) commitNormal(ctx context.Context, session *SafeSession) error } // commit2PC will not used the pinned tablets - to make sure we use the current source, we need to use the gateway's queryservice -func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err error) { +func (txc *TxConn) commit2PC(ctx context.Context, session *econtext.SafeSession) (err error) { // If the number of participants is one or less, then it's a normal commit. if len(session.ShardSessions) <= 1 { return txc.commitNormal(ctx, session) @@ -249,7 +250,7 @@ func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err err } txPhase = Commit2pcPrepare - prepareAction := func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error { + prepareAction := func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *econtext.ExecuteLogger) error { if DebugTwoPc { // Test code to simulate a failure during RM prepare if terr := checkTestFailure(ctx, "RMPrepare_-40_FailNow", s.Target); terr != nil { return terr @@ -257,7 +258,7 @@ func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err err } return txc.tabletGateway.Prepare(ctx, s.Target, s.TransactionId, dtid) } - if err = txc.runSessions(ctx, rmShards, session.logging, prepareAction); err != nil { + if err = txc.runSessions(ctx, rmShards, session.GetLogger(), prepareAction); err != nil { return err } @@ -280,7 +281,7 @@ func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err err } txPhase = Commit2pcPrepareCommit - prepareCommitAction := func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error { + prepareCommitAction := func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *econtext.ExecuteLogger) error { if DebugTwoPc { // Test code to simulate a failure during RM prepare if terr := checkTestFailure(ctx, "RMCommit_-40_FailNow", s.Target); terr != nil { return terr @@ -288,7 +289,7 @@ func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err err } return txc.tabletGateway.CommitPrepared(ctx, s.Target, dtid) } - if err = txc.runSessions(ctx, rmShards, session.logging, prepareCommitAction); err != nil { + if err = txc.runSessions(ctx, rmShards, session.GetLogger(), prepareCommitAction); err != nil { return err } @@ -300,7 +301,7 @@ func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) (err err return nil } -func (txc *TxConn) checkValidCondition(session *SafeSession) error { +func (txc *TxConn) checkValidCondition(session *econtext.SafeSession) error { if len(session.PreSessions) != 0 || len(session.PostSessions) != 0 { return vterrors.VT12001("atomic distributed transaction commit with consistent lookup vindex") } @@ -309,7 +310,7 @@ func (txc *TxConn) checkValidCondition(session *SafeSession) error { func (txc *TxConn) errActionAndLogWarn( ctx context.Context, - session *SafeSession, + session *econtext.SafeSession, txPhase commitPhase, startCommitState querypb.StartCommitState, dtid string, @@ -323,12 +324,12 @@ func (txc *TxConn) errActionAndLogWarn( rollbackErr = txc.Rollback(ctx, session) case Commit2pcPrepare: // Rollback the prepared and unprepared transactions. - rollbackErr = txc.rollbackTx(ctx, dtid, mmShard, rmShards, session.logging) + rollbackErr = txc.rollbackTx(ctx, dtid, mmShard, rmShards, session.GetLogger()) case Commit2pcStartCommit: // Failed to store the commit decision on MM. // If the failure state is certain, then the only option is to rollback the prepared transactions on the RMs. if startCommitState == querypb.StartCommitState_Fail { - rollbackErr = txc.rollbackTx(ctx, dtid, mmShard, rmShards, session.logging) + rollbackErr = txc.rollbackTx(ctx, dtid, mmShard, rmShards, session.GetLogger()) } fallthrough case Commit2pcPrepareCommit: @@ -362,7 +363,7 @@ func createWarningMessage(dtid string, txPhase commitPhase) string { } // Rollback rolls back the current transaction. There are no retries on this operation. -func (txc *TxConn) Rollback(ctx context.Context, session *SafeSession) error { +func (txc *TxConn) Rollback(ctx context.Context, session *econtext.SafeSession) error { if !session.InTransaction() { return nil } @@ -371,7 +372,7 @@ func (txc *TxConn) Rollback(ctx context.Context, session *SafeSession) error { allsessions := append(session.PreSessions, session.ShardSessions...) allsessions = append(allsessions, session.PostSessions...) - err := txc.runSessions(ctx, allsessions, session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error { + err := txc.runSessions(ctx, allsessions, session.GetLogger(), func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *econtext.ExecuteLogger) error { if s.TransactionId == 0 { return nil } @@ -385,7 +386,7 @@ func (txc *TxConn) Rollback(ctx context.Context, session *SafeSession) error { } s.TransactionId = 0 s.ReservedId = reservedID - logging.log(nil, s.Target, nil, "rollback", false, nil) + logging.Log(nil, s.Target, nil, "rollback", false, nil) return nil }) if err != nil { @@ -398,7 +399,7 @@ func (txc *TxConn) Rollback(ctx context.Context, session *SafeSession) error { } // Release releases the reserved connection and/or rollbacks the transaction -func (txc *TxConn) Release(ctx context.Context, session *SafeSession) error { +func (txc *TxConn) Release(ctx context.Context, session *econtext.SafeSession) error { if !session.InTransaction() && !session.InReservedConn() { return nil } @@ -407,7 +408,7 @@ func (txc *TxConn) Release(ctx context.Context, session *SafeSession) error { allsessions := append(session.PreSessions, session.ShardSessions...) allsessions = append(allsessions, session.PostSessions...) - return txc.runSessions(ctx, allsessions, session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error { + return txc.runSessions(ctx, allsessions, session.GetLogger(), func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *econtext.ExecuteLogger) error { if s.ReservedId == 0 && s.TransactionId == 0 { return nil } @@ -426,7 +427,7 @@ func (txc *TxConn) Release(ctx context.Context, session *SafeSession) error { } // ReleaseLock releases the reserved connection used for locking. -func (txc *TxConn) ReleaseLock(ctx context.Context, session *SafeSession) error { +func (txc *TxConn) ReleaseLock(ctx context.Context, session *econtext.SafeSession) error { if !session.InLockSession() { return nil } @@ -445,7 +446,7 @@ func (txc *TxConn) ReleaseLock(ctx context.Context, session *SafeSession) error } // ReleaseAll releases all the shard sessions and lock session. -func (txc *TxConn) ReleaseAll(ctx context.Context, session *SafeSession) error { +func (txc *TxConn) ReleaseAll(ctx context.Context, session *econtext.SafeSession) error { if !session.InTransaction() && !session.InReservedConn() && !session.InLockSession() { return nil } @@ -457,7 +458,7 @@ func (txc *TxConn) ReleaseAll(ctx context.Context, session *SafeSession) error { allsessions = append(allsessions, session.LockSession) } - return txc.runSessions(ctx, allsessions, session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, loggging *executeLogger) error { + return txc.runSessions(ctx, allsessions, session.GetLogger(), func(ctx context.Context, s *vtgatepb.Session_ShardSession, loggging *econtext.ExecuteLogger) error { if s.ReservedId == 0 && s.TransactionId == 0 { return nil } @@ -529,12 +530,12 @@ func (txc *TxConn) resolveTx(ctx context.Context, target *querypb.Target, transa // rollbackTx rollbacks the specified distributed transaction. // Rollbacks happens on the metadata manager and all participants irrespective of the failure. -func (txc *TxConn) rollbackTx(ctx context.Context, dtid string, mmShard *vtgatepb.Session_ShardSession, participants []*vtgatepb.Session_ShardSession, logging *executeLogger) error { +func (txc *TxConn) rollbackTx(ctx context.Context, dtid string, mmShard *vtgatepb.Session_ShardSession, participants []*vtgatepb.Session_ShardSession, logging *econtext.ExecuteLogger) error { var errs []error if mmErr := txc.rollbackMM(ctx, dtid, mmShard); mmErr != nil { errs = append(errs, mmErr) } - if rmErr := txc.runSessions(ctx, participants, logging, func(ctx context.Context, session *vtgatepb.Session_ShardSession, logger *executeLogger) error { + if rmErr := txc.runSessions(ctx, participants, logging, func(ctx context.Context, session *vtgatepb.Session_ShardSession, logger *econtext.ExecuteLogger) error { return txc.tabletGateway.RollbackPrepared(ctx, session.Target, dtid, session.TransactionId) }); rmErr != nil { errs = append(errs, rmErr) @@ -575,7 +576,7 @@ func (txc *TxConn) resumeCommit(ctx context.Context, target *querypb.Target, tra } // runSessions executes the action for all shardSessions in parallel and returns a consolidated error. -func (txc *TxConn) runSessions(ctx context.Context, shardSessions []*vtgatepb.Session_ShardSession, logging *executeLogger, action func(context.Context, *vtgatepb.Session_ShardSession, *executeLogger) error) error { +func (txc *TxConn) runSessions(ctx context.Context, shardSessions []*vtgatepb.Session_ShardSession, logging *econtext.ExecuteLogger, action func(context.Context, *vtgatepb.Session_ShardSession, *econtext.ExecuteLogger) error) error { // Fastpath. if len(shardSessions) == 1 { return action(ctx, shardSessions[0], logging) diff --git a/go/vt/vtgate/tx_conn_test.go b/go/vt/vtgate/tx_conn_test.go index 9d49626f6f1..333094569c8 100644 --- a/go/vt/vtgate/tx_conn_test.go +++ b/go/vt/vtgate/tx_conn_test.go @@ -26,6 +26,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" + "vitess.io/vitess/go/event/syslogger" "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/test/utils" @@ -51,7 +53,7 @@ func TestTxConnBegin(t *testing.T) { session := &vtgatepb.Session{} // begin - safeSession := NewSafeSession(session) + safeSession := econtext.NewSafeSession(session) err := sc.txConn.Begin(ctx, safeSession, nil) require.NoError(t, err) wantSession := vtgatepb.Session{InTransaction: true} @@ -75,7 +77,7 @@ func TestTxConnCommitFailure(t *testing.T) { // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rssm[0], queries, session, false, false, nullResultsObserver{}) wantSession := vtgatepb.Session{ InTransaction: true, @@ -176,7 +178,7 @@ func TestTxConnCommitFailureAfterNonAtomicCommitMaxShards(t *testing.T) { // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) wantSession := vtgatepb.Session{ InTransaction: true, ShardSessions: []*vtgatepb.Session_ShardSession{}, @@ -230,7 +232,7 @@ func TestTxConnCommitFailureBeforeNonAtomicCommitMaxShards(t *testing.T) { // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) wantSession := vtgatepb.Session{ InTransaction: true, ShardSessions: []*vtgatepb.Session_ShardSession{}, @@ -282,7 +284,7 @@ func TestTxConnCommitSuccess(t *testing.T) { sc.txConn.mode = vtgatepb.TransactionMode_MULTI // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) wantSession := vtgatepb.Session{ InTransaction: true, @@ -335,7 +337,7 @@ func TestTxConnReservedCommitSuccess(t *testing.T) { sc.txConn.mode = vtgatepb.TransactionMode_MULTI // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) wantSession := vtgatepb.Session{ InTransaction: true, @@ -420,7 +422,7 @@ func TestTxConnReservedOn2ShardTxOn1ShardAndCommit(t *testing.T) { sc.txConn.mode = vtgatepb.TransactionMode_MULTI // Sequence the executes to ensure shard session order - session := NewSafeSession(&vtgatepb.Session{InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InReservedConn: true}) // this will create reserved connections against all tablets _, errs := sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{}) @@ -515,7 +517,7 @@ func TestTxConnReservedOn2ShardTxOn1ShardAndRollback(t *testing.T) { sc.txConn.mode = vtgatepb.TransactionMode_MULTI // Sequence the executes to ensure shard session order - session := NewSafeSession(&vtgatepb.Session{InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InReservedConn: true}) // this will create reserved connections against all tablets _, errs := sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{}) @@ -611,7 +613,7 @@ func TestTxConnCommitOrderFailure1(t *testing.T) { queries := []*querypb.BoundQuery{{Sql: "query1"}} // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) session.SetCommitOrder(vtgatepb.CommitOrder_PRE) @@ -646,7 +648,7 @@ func TestTxConnCommitOrderFailure2(t *testing.T) { }} // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(context.Background(), nil, rss1, queries, session, false, false, nullResultsObserver{}) session.SetCommitOrder(vtgatepb.CommitOrder_PRE) @@ -680,7 +682,7 @@ func TestTxConnCommitOrderFailure3(t *testing.T) { }} // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) session.SetCommitOrder(vtgatepb.CommitOrder_PRE) @@ -722,7 +724,7 @@ func TestTxConnCommitOrderSuccess(t *testing.T) { }} // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) wantSession := vtgatepb.Session{ InTransaction: true, @@ -820,7 +822,7 @@ func TestTxConnReservedCommitOrderSuccess(t *testing.T) { }} // Sequence the executes to ensure commit order - session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) wantSession := vtgatepb.Session{ InTransaction: true, @@ -957,7 +959,7 @@ func TestTxConnCommit2PC(t *testing.T) { sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, ctx, "TestTxConnCommit2PC") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) session.TransactionMode = vtgatepb.TransactionMode_TWOPC @@ -974,7 +976,7 @@ func TestTxConnCommit2PCOneParticipant(t *testing.T) { ctx := utils.LeakCheckContext(t) sc, sbc0, _, rss0, _, _ := newTestTxConnEnv(t, ctx, "TestTxConnCommit2PCOneParticipant") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) session.TransactionMode = vtgatepb.TransactionMode_TWOPC require.NoError(t, @@ -987,7 +989,7 @@ func TestTxConnCommit2PCCreateTransactionFail(t *testing.T) { sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, ctx, "TestTxConnCommit2PCCreateTransactionFail") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{}) @@ -1009,7 +1011,7 @@ func TestTxConnCommit2PCPrepareFail(t *testing.T) { sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, ctx, "TestTxConnCommit2PCPrepareFail") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) @@ -1035,7 +1037,7 @@ func TestTxConnCommit2PCStartCommitFail(t *testing.T) { sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, ctx, "TestTxConnCommit2PCStartCommitFail") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) @@ -1054,7 +1056,7 @@ func TestTxConnCommit2PCStartCommitFail(t *testing.T) { sbc0.ResetCounter() sbc1.ResetCounter() - session = NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session = econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) @@ -1077,7 +1079,7 @@ func TestTxConnCommit2PCCommitPreparedFail(t *testing.T) { sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, ctx, "TestTxConnCommit2PCCommitPreparedFail") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) @@ -1097,7 +1099,7 @@ func TestTxConnCommit2PCConcludeTransactionFail(t *testing.T) { sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, ctx, "TestTxConnCommit2PCConcludeTransactionFail") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) @@ -1117,7 +1119,7 @@ func TestTxConnRollback(t *testing.T) { sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, ctx, "TxConnRollback") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) require.NoError(t, @@ -1133,7 +1135,7 @@ func TestTxConnReservedRollback(t *testing.T) { sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, ctx, "TxConnReservedRollback") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) require.NoError(t, @@ -1170,7 +1172,7 @@ func TestTxConnReservedRollbackFailure(t *testing.T) { sc, sbc0, sbc1, rss0, rss1, rss01 := newTestTxConnEnv(t, ctx, "TxConnReservedRollback") - session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) + session := econtext.NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true}) sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{}) sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false, nullResultsObserver{}) @@ -1449,7 +1451,7 @@ func TestTxConnMultiGoSessions(t *testing.T) { Keyspace: "0", }, }} - err := txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *executeLogger) error { + err := txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *econtext.ExecuteLogger) error { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "err %s", s.Target.Keyspace) }) want := "err 0" @@ -1464,7 +1466,7 @@ func TestTxConnMultiGoSessions(t *testing.T) { Keyspace: "1", }, }} - err = txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *executeLogger) error { + err = txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *econtext.ExecuteLogger) error { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "err %s", s.Target.Keyspace) }) want = "err 0\nerr 1" @@ -1472,7 +1474,7 @@ func TestTxConnMultiGoSessions(t *testing.T) { wantCode := vtrpcpb.Code_INTERNAL assert.Equal(t, wantCode, vterrors.Code(err), "error code") - err = txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *executeLogger) error { + err = txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *econtext.ExecuteLogger) error { return nil }) require.NoError(t, err) @@ -1515,7 +1517,7 @@ func TestTxConnAccessModeReset(t *testing.T) { tcases := []struct { name string - f func(ctx context.Context, session *SafeSession) error + f func(ctx context.Context, session *econtext.SafeSession) error }{{ name: "begin-commit", f: sc.txConn.Commit, @@ -1532,7 +1534,7 @@ func TestTxConnAccessModeReset(t *testing.T) { for _, tcase := range tcases { t.Run(tcase.name, func(t *testing.T) { - safeSession := NewSafeSession(&vtgatepb.Session{ + safeSession := econtext.NewSafeSession(&vtgatepb.Session{ Options: &querypb.ExecuteOptions{ TransactionAccessMode: []querypb.ExecuteOptions_TransactionAccessMode{querypb.ExecuteOptions_READ_ONLY}, }, diff --git a/go/vt/vtgate/vschema_manager.go b/go/vt/vtgate/vschema_manager.go index 2b6761f4a8e..62ea2cd3455 100644 --- a/go/vt/vtgate/vschema_manager.go +++ b/go/vt/vtgate/vschema_manager.go @@ -33,8 +33,6 @@ import ( vschemapb "vitess.io/vitess/go/vt/proto/vschema" ) -var _ VSchemaOperator = (*VSchemaManager)(nil) - // VSchemaManager is used to watch for updates to the vschema and to implement // the DDL commands to add / remove vindexes type VSchemaManager struct { diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 1628a6253eb..8bab05479dd 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -52,6 +52,7 @@ import ( "vitess.io/vitess/go/vt/topo/topoproto" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" + econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" vtschema "vitess.io/vitess/go/vt/vtgate/schema" "vitess.io/vitess/go/vt/vtgate/txresolver" @@ -488,7 +489,7 @@ func (vtg *VTGate) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConn if bvErr := sqltypes.ValidateBindVariables(bindVariables); bvErr != nil { err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", bvErr) } else { - safeSession := NewSafeSession(session) + safeSession := econtext.NewSafeSession(session) qr, err = vtg.executor.Execute(ctx, mysqlCtx, "Execute", safeSession, sql, bindVariables) safeSession.RemoveInternalSavepoint() } @@ -545,7 +546,7 @@ func (vtg *VTGate) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MyS defer vtg.timings.Record(statsKey, time.Now()) - safeSession := NewSafeSession(session) + safeSession := econtext.NewSafeSession(session) var err error if bvErr := sqltypes.ValidateBindVariables(bindVariables); bvErr != nil { err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", bvErr) @@ -579,7 +580,7 @@ func (vtg *VTGate) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MyS // same effect as if a "rollback" statement was executed, but does not affect the query // statistics. func (vtg *VTGate) CloseSession(ctx context.Context, session *vtgatepb.Session) error { - return vtg.executor.CloseSession(ctx, NewSafeSession(session)) + return vtg.executor.CloseSession(ctx, econtext.NewSafeSession(session)) } // Prepare supports non-streaming prepare statement query with multi shards @@ -594,7 +595,7 @@ func (vtg *VTGate) Prepare(ctx context.Context, session *vtgatepb.Session, sql s goto handleError } - fld, err = vtg.executor.Prepare(ctx, "Prepare", NewSafeSession(session), sql, bindVariables) + fld, err = vtg.executor.Prepare(ctx, "Prepare", econtext.NewSafeSession(session), sql, bindVariables) if err == nil { return session, fld, nil }