Skip to content

Commit

Permalink
feat: fix engine and add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
GuptaManan100 committed Aug 20, 2024
1 parent d6b9361 commit c4585d7
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 4 deletions.
10 changes: 10 additions & 0 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"reflect"
"strings"
"testing"
"time"

"golang.org/x/sync/errgroup"

Expand All @@ -41,6 +42,9 @@ type fakePrimitive struct {

log []string

// sleepTime is the time for which the fake primitive sleeps before returning the results.
sleepTime time.Duration

allResultsInOneCall bool

async bool
Expand Down Expand Up @@ -71,6 +75,9 @@ func (f *fakePrimitive) GetTableName() string {

func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields))
if f.sleepTime != 0 {
time.Sleep(f.sleepTime)
}
if f.results == nil {
return nil, f.sendErr
}
Expand All @@ -85,6 +92,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar

func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
if f.sleepTime != 0 {
time.Sleep(f.sleepTime)
}
if f.results == nil {
return f.sendErr
}
Expand Down
8 changes: 6 additions & 2 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ var _ SessionActions = (*noopVCursor)(nil)

// noopVCursor is used to build other vcursors.
type noopVCursor struct {
inTx bool
inTx bool
queryTimeout int
}

// MySQLVersion implements VCursor.
Expand Down Expand Up @@ -298,7 +299,10 @@ func (t *noopVCursor) SetQueryTimeout(maxExecutionTime int64) {
}

func (t *noopVCursor) GetQueryTimeout(queryTimeoutFromComments int) int {
return queryTimeoutFromComments
if queryTimeoutFromComments != 0 {
return queryTimeoutFromComments
}
return t.queryTimeout
}

func (t *noopVCursor) SetSkipQueryPlanCache(context.Context, bool) error {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/timeout_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (t *TimeoutHandler) TryExecute(ctx context.Context, vcursor VCursor, bindVa
ctx, cancel := addQueryTimeout(ctx, vcursor, t.Timeout)
defer cancel()

var complete chan any
complete := make(chan any)
go func() {
res, err = t.Input.TryExecute(ctx, vcursor, bindVars, wantfields)
close(complete)
Expand All @@ -73,7 +73,7 @@ func (t *TimeoutHandler) TryStreamExecute(ctx context.Context, vcursor VCursor,
ctx, cancel := addQueryTimeout(ctx, vcursor, t.Timeout)
defer cancel()

var complete chan any
complete := make(chan any)
go func() {
err = t.Input.TryStreamExecute(ctx, vcursor, bindVars, wantfields, callback)
close(complete)
Expand Down
85 changes: 85 additions & 0 deletions go/vt/vtgate/engine/timeout_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package engine

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
)

// TestTimeoutHandler tests timeout handler primitive.
func TestTimeoutHandler(t *testing.T) {
tests := []struct {
name string
input *TimeoutHandler
vc VCursor
wantErr string
}{
{
name: "No timeout",
input: NewTimeoutHandler(&fakePrimitive{
results: nil,
sleepTime: 100 * time.Millisecond,
}, 0),
vc: &noopVCursor{},
wantErr: "",
}, {
name: "Timeout without failure",
input: NewTimeoutHandler(&fakePrimitive{
results: nil,
sleepTime: 100 * time.Millisecond,
}, 1000),
vc: &noopVCursor{},
wantErr: "",
}, {
name: "Timeout in session",
input: NewTimeoutHandler(&fakePrimitive{
results: nil,
sleepTime: 2 * time.Second,
}, 0),
vc: &noopVCursor{
queryTimeout: 100,
},
wantErr: "VT15001: Query execution was interrupted, maximum statement execution time exceeded",
}, {
name: "Timeout in comments",
input: NewTimeoutHandler(&fakePrimitive{
results: nil,
sleepTime: 2 * time.Second,
}, 100),
vc: &noopVCursor{},
wantErr: "VT15001: Query execution was interrupted, maximum statement execution time exceeded",
}, {
name: "Timeout in both",
input: NewTimeoutHandler(&fakePrimitive{
results: nil,
sleepTime: 2 * time.Second,
}, 100),
vc: &noopVCursor{
queryTimeout: 4000,
},
wantErr: "VT15001: Query execution was interrupted, maximum statement execution time exceeded",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := tt.input.TryExecute(context.Background(), tt.vc, nil, false)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
err = tt.input.TryStreamExecute(context.Background(), tt.vc, nil, false, func(result *sqltypes.Result) error {
return nil
})
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}

0 comments on commit c4585d7

Please sign in to comment.