Skip to content

Commit

Permalink
Fix Data race in semi-join (#17417)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
GuptaManan100 authored Dec 28, 2024
1 parent 9383943 commit c25802d
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 7 deletions.
20 changes: 20 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,23 @@ func TestTimeZones(t *testing.T) {
})
}
}

// TestSemiJoin tests that the semi join works as intended.
func TestSemiJoin(t *testing.T) {
mcmp, closer := start(t)
defer closer()

for i := 1; i <= 1000; i++ {
mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i))
mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i))
}

// Test that the semi join works as intended
for _, mode := range []string{"oltp", "olap"} {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1")
})
}
}
7 changes: 5 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type fakePrimitive struct {
// sendErr is sent at the end of the stream if it's set.
sendErr error

log []string
noLog bool
log []string

allResultsInOneCall bool

Expand Down Expand Up @@ -85,7 +86,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar
}

func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
if !f.noLog {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
}
if f.results == nil {
return f.sendErr
}
Expand Down
13 changes: 8 additions & 5 deletions go/vt/vtgate/engine/semi_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -62,24 +63,26 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma

// TryStreamExecute performs a streaming exec.
func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
joinVars := make(map[string]*querypb.BindVariable)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
result := &sqltypes.Result{Fields: lresult.Fields}
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
rowAdded := false
var rowAdded atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error {
if len(rresult.Rows) > 0 && !rowAdded {
result.Rows = append(result.Rows, lrow)
rowAdded = true
if len(rresult.Rows) > 0 {
rowAdded.Store(true)
}
return nil
})
if err != nil {
return err
}
if rowAdded.Load() {
result.Rows = append(result.Rows, lrow)
}
}
return callback(result)
})
Expand Down
79 changes: 79 additions & 0 deletions go/vt/vtgate/engine/semi_join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"
"testing"

"vitess.io/vitess/go/test/utils"
Expand Down Expand Up @@ -159,3 +160,81 @@ func TestSemiJoinStreamExecute(t *testing.T) {
"4|d|dd",
))
}

// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution
// to ensure we have no data races.
func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) {
leftPrim := &fakePrimitive{
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
), sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"3|c|cc",
"4|d|dd",
),
},
async: true,
}
rightFields := sqltypes.MakeTestFields(
"col4|col5|col6",
"int64|varchar|varchar",
)
rightPrim := &fakePrimitive{
// we'll return non-empty results for rows 2 and 4
results: sqltypes.MakeTestStreamingResults(rightFields,
"4|d|dd",
"---",
"---",
"5|e|ee",
"6|f|ff",
"7|g|gg",
),
async: true,
noLog: true,
}

jn := &SemiJoin{
Left: leftPrim,
Right: rightPrim,
Vars: map[string]int{
"bv": 1,
},
}
var res *sqltypes.Result
var mu sync.Mutex
err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error {
mu.Lock()
defer mu.Unlock()
if res == nil {
res = result
} else {
res.Rows = append(res.Rows, result.Rows...)
}
return nil
})
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`StreamExecute true`,
})
// We'll get all the rows back in left primitive, since we're returning the same set of rows
// from the right primitive that makes them all qualify.
expectResultAnyOrder(t, res, sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
"4|d|dd",
))
}

0 comments on commit c25802d

Please sign in to comment.