Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resultsObserver to ScatterConn #16638

Merged
merged 9 commits into from
Aug 29, 2024
10 changes: 5 additions & 5 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ func (e *Executor) executeSPInAllSessions(ctx context.Context, safeSession *Safe
})
queries = append(queries, &querypb.BoundQuery{Sql: sql})
}
qr, errs = e.ExecuteMultiShard(ctx, nil, rss, queries, safeSession, false /*autocommit*/, ignoreMaxMemoryRows)
qr, errs = e.ExecuteMultiShard(ctx, nil, rss, queries, safeSession, false /*autocommit*/, ignoreMaxMemoryRows, nullResultsObserver{})
err := vterrors.Aggregate(errs)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1454,13 +1454,13 @@ func parseAndValidateQuery(query string, parser *sqlparser.Parser) (sqlparser.St
}

// ExecuteMultiShard implements the IExecutor interface
func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool) (qr *sqltypes.Result, errs []error) {
return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows)
func (e *Executor) ExecuteMultiShard(ctx context.Context, primitive engine.Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool, resultsObserver resultsObserver) (qr *sqltypes.Result, errs []error) {
return e.scatterConn.ExecuteMultiShard(ctx, primitive, rss, queries, session, autocommit, ignoreMaxMemoryRows, resultsObserver)
}

// StreamExecuteMulti implements the IExecutor interface
func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error) []error {
return e.scatterConn.StreamExecuteMulti(ctx, primitive, query, rss, vars, session, autocommit, callback)
func (e *Executor) StreamExecuteMulti(ctx context.Context, primitive engine.Primitive, query string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, session *SafeSession, autocommit bool, callback func(reply *sqltypes.Result) error, resultsObserver resultsObserver) []error {
return e.scatterConn.StreamExecuteMulti(ctx, primitive, query, rss, vars, session, autocommit, callback, resultsObserver)
}

// ExecuteLock implements the IExecutor interface
Expand Down
54 changes: 40 additions & 14 deletions go/vt/vtgate/legacy_scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestLegacyExecuteFailOnAutocommit(t *testing.T) {
},
Autocommit: false,
}
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{})
err := vterrors.Aggregate(errs)
require.Error(t, err)
require.Contains(t, err.Error(), "in autocommit mode, transactionID should be zero but was: 123")
Expand All @@ -123,7 +123,7 @@ func TestScatterConnExecuteMulti(t *testing.T) {
}
}

qr, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(nil), false /*autocommit*/, false)
qr, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(nil), false /*autocommit*/, false, nullResultsObserver{})
return qr, vterrors.Aggregate(errs)
})
}
Expand All @@ -143,7 +143,7 @@ func TestScatterConnStreamExecuteMulti(t *testing.T) {
defer mu.Unlock()
qr.AppendResult(r)
return nil
})
}, nullResultsObserver{})
return qr, vterrors.Aggregate(errors)
})
}
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestMaxMemoryRows(t *testing.T) {
sbc0.SetResults([]*sqltypes.Result{tworows, tworows})
sbc1.SetResults([]*sqltypes.Result{tworows, tworows})

_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, test.ignoreMaxMemoryRows)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, test.ignoreMaxMemoryRows, nullResultsObserver{})
if test.ignoreMaxMemoryRows {
require.NoError(t, err)
} else {
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestLegaceHealthCheckFailsOnReservedConnections(t *testing.T) {
})
}

_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, nullResultsObserver{})
require.Error(t, vterrors.Aggregate(errs))
}

Expand All @@ -365,10 +365,21 @@ func executeOnShardsReturnsErr(t *testing.T, ctx context.Context, res *srvtopo.R
})
}

_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, nullResultsObserver{})
return vterrors.Aggregate(errs)
}

type recordingResultsObserver struct {
mu sync.Mutex
recorded []*sqltypes.Result
}

func (o *recordingResultsObserver) observe(result *sqltypes.Result) {
mu.Lock()
o.recorded = append(o.recorded, result)
mu.Unlock()
}

func TestMultiExecs(t *testing.T) {
ctx := utils.LeakCheckContext(t)
createSandbox("TestMultiExecs")
Expand Down Expand Up @@ -409,9 +420,17 @@ func TestMultiExecs(t *testing.T) {
},
},
}
results := []*sqltypes.Result{
{Info: "r0"},
{Info: "r1"},
}
sbc0.SetResults(results[0:1])
sbc1.SetResults(results[1:2])

observer := recordingResultsObserver{}

session := NewSafeSession(&vtgatepb.Session{})
_, err := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false)
_, err := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, &observer)
require.NoError(t, vterrors.Aggregate(err))
if len(sbc0.Queries) == 0 || len(sbc1.Queries) == 0 {
t.Fatalf("didn't get expected query")
Expand All @@ -428,8 +447,12 @@ func TestMultiExecs(t *testing.T) {
if !reflect.DeepEqual(sbc1.Queries[0].BindVariables, wantVars1) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars1)
}
assert.ElementsMatch(t, results, observer.recorded)

sbc0.Queries = nil
sbc1.Queries = nil
sbc0.SetResults(results[0:1])
sbc1.SetResults(results[1:2])

rss = []*srvtopo.ResolvedShard{
{
Expand All @@ -455,15 +478,18 @@ func TestMultiExecs(t *testing.T) {
"bv1": sqltypes.Int64BindVariable(1),
},
}

observer = recordingResultsObserver{}
_ = sc.StreamExecuteMulti(ctx, nil, "query", rss, bvs, session, false /* autocommit */, func(*sqltypes.Result) error {
return nil
})
}, &observer)
if !reflect.DeepEqual(sbc0.Queries[0].BindVariables, wantVars0) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars0)
}
if !reflect.DeepEqual(sbc1.Queries[0].BindVariables, wantVars1) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars1)
}
assert.ElementsMatch(t, results, observer.recorded)
}

func TestScatterConnSingleDB(t *testing.T) {
Expand All @@ -487,27 +513,27 @@ func TestScatterConnSingleDB(t *testing.T) {
// TransactionMode_SINGLE in session
session := NewSafeSession(&vtgatepb.Session{InTransaction: true, TransactionMode: vtgatepb.TransactionMode_SINGLE})
queries := []*querypb.BoundQuery{{Sql: "query1"}}
_, errors := sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
_, errors := sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{})
require.Error(t, errors[0])
assert.Contains(t, errors[0].Error(), want)

// TransactionMode_SINGLE in txconn
sc.txConn.mode = vtgatepb.TransactionMode_SINGLE
session = NewSafeSession(&vtgatepb.Session{InTransaction: true})
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{})
require.Error(t, errors[0])
assert.Contains(t, errors[0].Error(), want)

// TransactionMode_MULTI in txconn. Should not fail.
sc.txConn.mode = vtgatepb.TransactionMode_MULTI
session = NewSafeSession(&vtgatepb.Session{InTransaction: true})
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
_, errors = sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false, nullResultsObserver{})
require.Empty(t, errors)
}

Expand Down
33 changes: 27 additions & 6 deletions go/vt/vtgate/scatter_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ type shardActionFunc func(rs *srvtopo.ResolvedShard, i int) error
// the results and errors for the caller.
type shardActionTransactionFunc func(rs *srvtopo.ResolvedShard, i int, shardActionInfo *shardActionInfo) (*shardActionInfo, error)

type (
resultsObserver interface {
observe(*sqltypes.Result)
}
Comment on lines +76 to +78
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this method not take in the shard information to track where this result set comes from?

nullResultsObserver struct{}
)

func (nullResultsObserver) observe(*sqltypes.Result) {}

// NewScatterConn creates a new ScatterConn.
func NewScatterConn(statsName string, txConn *TxConn, gw *TabletGateway) *ScatterConn {
// this only works with TabletGateway
Expand Down Expand Up @@ -146,6 +155,7 @@ func (stc *ScatterConn) ExecuteMultiShard(
session *SafeSession,
autocommit bool,
ignoreMaxMemoryRows bool,
resultsObserver resultsObserver,
) (qr *sqltypes.Result, errs []error) {

if len(rss) != len(queries) {
Expand Down Expand Up @@ -260,6 +270,10 @@ func (stc *ScatterConn) ExecuteMultiShard(
mu.Lock()
defer mu.Unlock()

if innerqr != nil {
resultsObserver.observe(innerqr)
}

Comment on lines +273 to +276
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it is given that innerqr will not be nil.
Just below, we append the result to the output result set as-sis without checking.

// Don't append more rows if row count is exceeded.
if ignoreMaxMemoryRows || len(qr.Rows) <= maxMemoryRows {
qr.AppendResult(innerqr)
Expand Down Expand Up @@ -354,11 +368,18 @@ func (stc *ScatterConn) StreamExecuteMulti(
session *SafeSession,
autocommit bool,
callback func(reply *sqltypes.Result) error,
resultsObserver resultsObserver,
) []error {
if session.InLockSession() && session.TriggerLockHeartBeat() {
go stc.runLockQuery(ctx, session)
}

observedCallback := func(reply *sqltypes.Result) error {
if reply != nil {
resultsObserver.observe(reply)
}
return callback(reply)
}
allErrors := stc.multiGoTransaction(
ctx,
"StreamExecute",
Expand Down Expand Up @@ -407,41 +428,41 @@ func (stc *ScatterConn) StreamExecuteMulti(

switch info.actionNeeded {
case nothing:
err = qs.StreamExecute(ctx, rs.Target, query, bindVars[i], transactionID, reservedID, opts, callback)
err = qs.StreamExecute(ctx, rs.Target, query, bindVars[i], transactionID, reservedID, opts, observedCallback)
if err != nil {
retryRequest(func() {
// we seem to have lost our connection. it was a reserved connection, let's try to recreate it
info.actionNeeded = reserve
var state queryservice.ReservedState
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], 0 /*transactionId*/, opts, callback)
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], 0 /*transactionId*/, opts, observedCallback)
reservedID = state.ReservedID
alias = state.TabletAlias
})
}
case begin:
var state queryservice.TransactionState
state, err = qs.BeginStreamExecute(ctx, rs.Target, session.SavePoints(), query, bindVars[i], reservedID, opts, callback)
state, err = qs.BeginStreamExecute(ctx, rs.Target, session.SavePoints(), query, bindVars[i], reservedID, opts, observedCallback)
transactionID = state.TransactionID
alias = state.TabletAlias
if err != nil {
retryRequest(func() {
// we seem to have lost our connection. it was a reserved connection, let's try to recreate it
info.actionNeeded = reserveBegin
var state queryservice.ReservedTransactionState
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, callback)
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, observedCallback)
transactionID = state.TransactionID
reservedID = state.ReservedID
alias = state.TabletAlias
})
}
case reserve:
var state queryservice.ReservedState
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], transactionID, opts, callback)
state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], transactionID, opts, observedCallback)
reservedID = state.ReservedID
alias = state.TabletAlias
case reserveBegin:
var state queryservice.ReservedTransactionState
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, callback)
state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, observedCallback)
transactionID = state.TransactionID
reservedID = state.ReservedID
alias = state.TabletAlias
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestExecuteFailOnAutocommit(t *testing.T) {
},
Autocommit: false,
}
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false)
_, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{})
err := vterrors.Aggregate(errs)
require.Error(t, err)
require.Contains(t, err.Error(), "in autocommit mode, transactionID should be zero but was: 123")
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestExecutePanic(t *testing.T) {
require.Contains(t, logMessage, "(*ScatterConn).multiGoTransaction")
}()

_, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false)
_, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{})

}

Expand Down
Loading
Loading