Skip to content

Commit 7800fe5

Browse files
[release-17.0] vtgate/engine: Fix race condition in join logic (#14435) (#14440)
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
Co-authored-by: vitess-bot[bot] <108069721+vitess-bot[bot]@users.noreply.github.com>
1 parent 3709c53 commit 7800fe5

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

go/vt/vtgate/engine/join.go

Lines changed: 28 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"fmt"
2222
"strings"
23+
"sync"
2324
"sync/atomic"
2425

2526
"vitess.io/vitess/go/sqltypes"
@@ -115,22 +116,31 @@ func bindvarForType(t querypb.Type) *querypb.BindVariable {
115116

116117
// TryStreamExecute performs a streaming exec.
117118
func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
118-
var fieldNeeded atomic.Bool
119-
fieldNeeded.Store(wantfields)
120-
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error {
119+
var mu sync.Mutex
120+
// We need to use this atomic since we're also reading this
121+
// value outside of it being locked with the mu lock.
122+
// This is still racy, but worst case it means that we may
123+
// retrieve the right hand side fields twice instead of once.
124+
var fieldsSent atomic.Bool
125+
fieldsSent.Store(!wantfields)
126+
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
121127
joinVars := make(map[string]*querypb.BindVariable)
122128
for _, lrow := range lresult.Rows {
123129
for k, col := range jn.Vars {
124130
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
125131
}
126132
var rowSent atomic.Bool
127-
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error {
133+
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), !fieldsSent.Load(), func(rresult *sqltypes.Result) error {
134+
// This needs to be locking since it's not safe to just use
135+
// fieldsSent. This is because we can't have a race between
136+
// checking fieldsSent and then actually calling the callback
137+
// and in parallel another goroutine doing the same. That
138+
// can lead to out of order execution of the callback. So the callback
139+
// itself and the check need to be covered by the same lock.
140+
mu.Lock()
141+
defer mu.Unlock()
128142
result := &sqltypes.Result{}
129-
if fieldNeeded.Load() {
130-
// This code is currently unreachable because the first result
131-
// will always be just the field info, which will cause the outer
132-
// wantfields code path to be executed. But this may change in the future.
133-
fieldNeeded.Store(false)
143+
if fieldsSent.CompareAndSwap(false, true) {
134144
result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
135145
}
136146
for _, rrow := range rresult.Rows {
@@ -154,8 +164,15 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
154164
return callback(result)
155165
}
156166
}
157-
if fieldNeeded.Load() {
158-
fieldNeeded.Store(false)
167+
// This needs to be locking since it's not safe to just use
168+
// fieldsSent. This is because we can't have a race between
169+
// checking fieldsSent and then actually calling the callback
170+
// and in parallel another goroutine doing the same. That
171+
// can lead to out of order execution of the callback. So the callback
172+
// itself and the check need to be covered by the same lock.
173+
mu.Lock()
174+
defer mu.Unlock()
175+
if fieldsSent.CompareAndSwap(false, true) {
159176
for k := range jn.Vars {
160177
joinVars[k] = sqltypes.NullBindVariable
161178
}

go/vt/vtgate/engine/scalar_aggregation.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,8 +129,8 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor
129129
var current []sqltypes.Value
130130
var curDistincts []sqltypes.Value
131131
var fields []*querypb.Field
132-
fieldsSent := false
133132
var mu sync.Mutex
133+
fieldsSent := !wantfields
134134

135135
err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error {
136136
// as the underlying primitive call is not sync

0 commit comments

Comments
 (0)