diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index e9b1d3d7712..7e1dcddeb99 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -22,6 +22,7 @@ import ( "io" "sort" "strings" + "sync" "sync/atomic" "time" @@ -133,6 +134,8 @@ type ( resultsObserver resultsObserver + // this protects the interOpStats and shardsStats fields from concurrent writes + mu sync.Mutex // this is a map of the number of rows that every primitive has returned // if this field is nil, it means that we are not logging operator traffic interOpStats map[engine.Primitive]engine.RowsReceived @@ -537,21 +540,29 @@ func (vc *vcursorImpl) ExecutePrimitive(ctx context.Context, primitive engine.Pr } func (vc *vcursorImpl) logOpTraffic(primitive engine.Primitive, res *sqltypes.Result) { - if vc.interOpStats != nil { - rows := vc.interOpStats[primitive] - if res == nil { - rows = append(rows, 0) - } else { - rows = append(rows, len(res.Rows)) - } - vc.interOpStats[primitive] = rows + if vc.interOpStats == nil { + return + } + + vc.mu.Lock() + defer vc.mu.Unlock() + + rows := vc.interOpStats[primitive] + if res == nil { + rows = append(rows, 0) + } else { + rows = append(rows, len(res.Rows)) } + vc.interOpStats[primitive] = rows } func (vc *vcursorImpl) logShardsQueried(primitive engine.Primitive, shardsNb int) { - if vc.shardsStats != nil { - vc.shardsStats[primitive] += engine.ShardsQueried(shardsNb) + if vc.shardsStats == nil { + return } + vc.mu.Lock() + defer vc.mu.Unlock() + vc.shardsStats[primitive] += engine.ShardsQueried(shardsNb) } func (vc *vcursorImpl) ExecutePrimitiveStandalone(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {