diff --git a/go/vt/vttablet/grpctabletconn/conn.go b/go/vt/vttablet/grpctabletconn/conn.go index 775118aee73..8bb8a466b21 100644 --- a/go/vt/vttablet/grpctabletconn/conn.go +++ b/go/vt/vttablet/grpctabletconn/conn.go @@ -473,6 +473,10 @@ func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, target *querypb.T // BeginStreamExecute starts a transaction and runs an Execute. func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.TransactionState, err error) { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -650,6 +654,9 @@ func (conn *gRPCQueryClient) StreamHealth(ctx context.Context, callback func(*qu // VStream starts a VReplication stream. func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb.VStreamRequest, send func([]*binlogdatapb.VEvent) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -695,6 +702,9 @@ func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb. // VStreamRows streams rows of a query from the specified starting point. func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdatapb.VStreamRowsRequest, send func(*binlogdatapb.VStreamRowsResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamRowsClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -737,6 +747,9 @@ func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdat // VStreamTables streams rows of a query from the specified starting point. func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogdatapb.VStreamTablesRequest, send func(*binlogdatapb.VStreamTablesResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamTablesClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -777,6 +790,9 @@ func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogd // VStreamResults streams rows of a query from the specified starting point. func (conn *gRPCQueryClient) VStreamResults(ctx context.Context, target *querypb.Target, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamResultsClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -856,6 +872,9 @@ func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, target *qu // ReserveBeginStreamExecute implements the queryservice interface func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedTransactionState, err error) { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -967,6 +986,9 @@ func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, target *querypb // ReserveStreamExecute implements the queryservice interface func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedState, err error) { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -1060,6 +1082,9 @@ func (conn *gRPCQueryClient) Release(ctx context.Context, target *querypb.Target // GetSchema implements the queryservice interface func (conn *gRPCQueryClient) GetSchema(ctx context.Context, target *querypb.Target, tableType querypb.SchemaTableType, tableNames []string, callback func(schemaRes *querypb.GetSchemaResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { diff --git a/go/vt/vttablet/grpctabletconn/conn_test.go b/go/vt/vttablet/grpctabletconn/conn_test.go index fb182bfe2e4..70e30e337bc 100644 --- a/go/vt/vttablet/grpctabletconn/conn_test.go +++ b/go/vt/vttablet/grpctabletconn/conn_test.go @@ -17,13 +17,21 @@ limitations under the License. package grpctabletconn import ( + "context" + "fmt" "io" "net" "os" + "sync" "testing" + "github.com/stretchr/testify/require" "google.golang.org/grpc" + "vitess.io/vitess/go/sqltypes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" + queryservicepb "vitess.io/vitess/go/vt/proto/queryservice" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vttablet/grpcqueryservice" "vitess.io/vitess/go/vt/vttablet/tabletconntest" @@ -113,3 +121,111 @@ func TestGRPCTabletAuthConn(t *testing.T) { }, }, service, f) } + +// mockQueryClient is a mock query client that returns an error from Streaming calls, +// but only after storing the context that was passed to the RPC. +type mockQueryClient struct { + lastCallCtx context.Context + queryservicepb.QueryClient +} + +func (m *mockQueryClient) StreamExecute(ctx context.Context, in *querypb.StreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_StreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) BeginStreamExecute(ctx context.Context, in *querypb.BeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_BeginStreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) ReserveStreamExecute(ctx context.Context, in *querypb.ReserveStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveStreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) ReserveBeginStreamExecute(ctx context.Context, in *querypb.ReserveBeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveBeginStreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) VStream(ctx context.Context, in *binlogdatapb.VStreamRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) VStreamRows(ctx context.Context, in *binlogdatapb.VStreamRowsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamRowsClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) VStreamTables(ctx context.Context, in *binlogdatapb.VStreamTablesRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamTablesClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) VStreamResults(ctx context.Context, in *binlogdatapb.VStreamResultsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamResultsClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) GetSchema(ctx context.Context, in *querypb.GetSchemaRequest, opts ...grpc.CallOption) (queryservicepb.Query_GetSchemaClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +var _ queryservicepb.QueryClient = (*mockQueryClient)(nil) + +// TestGoRoutineLeakPrevention tests that after all the RPCs that stream queries, we end up closing the context that was passed to it, to prevent go routines from being leaked. +func TestGoRoutineLeakPrevention(t *testing.T) { + mqc := &mockQueryClient{} + qc := &gRPCQueryClient{ + mu: sync.RWMutex{}, + cc: &grpc.ClientConn{}, + c: mqc, + } + _ = qc.StreamExecute(context.Background(), nil, "", nil, 0, 0, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _, _ = qc.BeginStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _, _ = qc.ReserveBeginStreamExecute(context.Background(), nil, nil, nil, "", nil, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _, _ = qc.ReserveStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStream(context.Background(), &binlogdatapb.VStreamRequest{}, func(events []*binlogdatapb.VEvent) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStreamRows(context.Background(), &binlogdatapb.VStreamRowsRequest{}, func(response *binlogdatapb.VStreamRowsResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStreamResults(context.Background(), nil, "", func(response *binlogdatapb.VStreamResultsResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStreamTables(context.Background(), &binlogdatapb.VStreamTablesRequest{}, func(response *binlogdatapb.VStreamTablesResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.GetSchema(context.Background(), nil, querypb.SchemaTableType_TABLES, nil, func(schemaRes *querypb.GetSchemaResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) +}