Skip to content

Commit

Permalink
Always make sure to escape all strings
Browse files Browse the repository at this point in the history
Don't directly interpolate these strings. We don't know of any user
controllable ways to do this, but it's still too risky to ever do this.
We always need to escape all strings.

Ideally we refactor this as well to use better statement binding in the
future.

Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
dbussink committed Jan 29, 2025
1 parent fd1186c commit 359128d
Show file tree
Hide file tree
Showing 14 changed files with 31 additions and 55 deletions.
4 changes: 1 addition & 3 deletions go/test/endtoend/vreplication/vdiff_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,7 @@ func getVDiffInfo(json string) *vdiffInfo {
}

func encodeString(in string) string {
var buf strings.Builder
sqltypes.NewVarChar(in).EncodeSQL(&buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

// generateMoreCustomers creates additional test data for better tests
Expand Down
22 changes: 10 additions & 12 deletions go/vt/binlog/binlogplayer/binlog_player.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ func (blp *BinlogPlayer) setVReplicationState(state binlogdatapb.VReplicationWor
})
}
blp.blplStats.State.Store(state.String())
query := fmt.Sprintf("update _vt.vreplication set state='%v', message=%v where id=%v", state.String(), encodeString(MessageTruncate(message)), blp.uid)
query := fmt.Sprintf("update _vt.vreplication set state=%v, message=%v where id=%v", encodeString(state.String()), encodeString(MessageTruncate(message)), blp.uid)
if _, err := blp.dbClient.ExecuteFetch(query, 1); err != nil {
return fmt.Errorf("could not set state: %v: %v", query, err)
}
Expand Down Expand Up @@ -637,9 +637,9 @@ func CreateVReplication(workflow string, source *binlogdatapb.BinlogSource, posi
protoutil.SortBinlogSourceTables(source)
return fmt.Sprintf("insert into _vt.vreplication "+
"(workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name, workflow_type, workflow_sub_type, defer_secondary_keys, options) "+
"values (%v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %v, %s)",
"values (%v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %v, %s)",
encodeString(workflow), encodeString(source.String()), encodeString(position), maxTPS, maxReplicationLag,
timeUpdated, binlogdatapb.VReplicationWorkflowState_Running.String(), encodeString(dbName), workflowType,
timeUpdated, encodeString(binlogdatapb.VReplicationWorkflowState_Running.String()), encodeString(dbName), workflowType,
workflowSubType, deferSecondaryKeys, encodeString("{}"))
}

Expand All @@ -649,9 +649,9 @@ func CreateVReplicationState(workflow string, source *binlogdatapb.BinlogSource,
protoutil.SortBinlogSourceTables(source)
return fmt.Sprintf("insert into _vt.vreplication "+
"(workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name, workflow_type, workflow_sub_type, options) "+
"values (%v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %s)",
"values (%v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %s)",
encodeString(workflow), encodeString(source.String()), encodeString(position), throttler.MaxRateModuleDisabled,
throttler.ReplicationLagModuleDisabled, time.Now().Unix(), state.String(), encodeString(dbName),
throttler.ReplicationLagModuleDisabled, time.Now().Unix(), encodeString(state.String()), encodeString(dbName),
workflowType, workflowSubType, encodeString("{}"))
}

Expand Down Expand Up @@ -694,15 +694,15 @@ func GenerateUpdateTimeThrottled(uid int32, timeThrottledUnix int64, componentTh
// StartVReplicationUntil returns a statement to start the replication with a stop position.
func StartVReplicationUntil(uid int32, pos string) string {
return fmt.Sprintf(
"update _vt.vreplication set state='%v', stop_pos=%v where id=%v",
binlogdatapb.VReplicationWorkflowState_Running.String(), encodeString(pos), uid)
"update _vt.vreplication set state=%v, stop_pos=%v where id=%v",
encodeString(binlogdatapb.VReplicationWorkflowState_Running.String()), encodeString(pos), uid)
}

// StopVReplication returns a statement to stop the replication.
func StopVReplication(uid int32, message string) string {
return fmt.Sprintf(
"update _vt.vreplication set state='%v', message=%v where id=%v",
binlogdatapb.VReplicationWorkflowState_Stopped.String(), encodeString(MessageTruncate(message)), uid)
"update _vt.vreplication set state=%v, message=%v where id=%v",
encodeString(binlogdatapb.VReplicationWorkflowState_Stopped.String()), encodeString(MessageTruncate(message)), uid)
}

// DeleteVReplication returns a statement to delete the replication.
Expand All @@ -717,9 +717,7 @@ func MessageTruncate(msg string) string {
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

// ReadVReplicationPos returns a statement to query the gtid for a
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtctl/vdiff_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func newTestVDiffEnv(t testing.TB, ctx context.Context, sourceShards, targetShar
// But this is one statement per stream.
env.tmc.setVRResults(
primary.tablet,
fmt.Sprintf("update _vt.vreplication set state='Running', stop_pos='%s', message='synchronizing for vdiff' where id=%d", vdiffSourceGtid, j+1),
fmt.Sprintf("update _vt.vreplication set state='Running', stop_pos=%s, message='synchronizing for vdiff' where id=%d", sqltypes.EncodeStringSQL(vdiffSourceGtid), j+1),
&sqltypes.Result{},
)
}
Expand Down
1 change: 0 additions & 1 deletion go/vt/vtctl/workflow/resharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ func (rs *resharder) createStreams(ctx context.Context) error {
if err != nil {
return err
}
optionsJSON = fmt.Sprintf("'%s'", optionsJSON)
for _, source := range rs.sourceShards {
if !key.KeyRangeIntersect(target.KeyRange, source.KeyRange) {
continue
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtctl/workflow/traffic_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -858,8 +858,8 @@ func (ts *trafficSwitcher) getReverseVReplicationUpdateQuery(targetCell string,
}

if ts.optCells != "" || ts.optTabletTypes != "" {
query := fmt.Sprintf("update _vt.vreplication set cell = '%s', tablet_types = '%s', options = '%s' where workflow = '%s' and db_name = '%s'",
ts.optCells, ts.optTabletTypes, options, ts.ReverseWorkflowName(), dbname)
query := fmt.Sprintf("update _vt.vreplication set cell = %s, tablet_types = %s, options = %s where workflow = %s and db_name = %s",
sqltypes.EncodeStringSQL(ts.optCells), sqltypes.EncodeStringSQL(ts.optTabletTypes), sqltypes.EncodeStringSQL(options), sqltypes.EncodeStringSQL(ts.ReverseWorkflowName()), sqltypes.EncodeStringSQL(dbname))
return query
}
return ""
Expand Down Expand Up @@ -941,8 +941,8 @@ func (ts *trafficSwitcher) createReverseVReplication(ctx context.Context) error
// For non-reference tables we return an error if there's no primary
// vindex as it's not clear what to do.
if len(vtable.ColumnVindexes) > 0 && len(vtable.ColumnVindexes[0].Columns) > 0 {
inKeyrange = fmt.Sprintf(" where in_keyrange(%s, '%s.%s', '%s')", sqlparser.String(vtable.ColumnVindexes[0].Columns[0]),
ts.SourceKeyspaceName(), vtable.ColumnVindexes[0].Name, key.KeyRangeString(source.GetShard().KeyRange))
inKeyrange = fmt.Sprintf(" where in_keyrange(%s, %s, %s)", sqlparser.String(vtable.ColumnVindexes[0].Columns[0]),
sqlparser.String(sqlparser.NewTableNameWithQualifier(vtable.ColumnVindexes[0].Name, ts.SourceKeyspaceName())), sqltypes.EncodeStringSQL(key.KeyRangeString(source.GetShard().KeyRange)))
} else {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no primary vindex found for the %s table in the %s keyspace",
vtable.Name.String(), ts.SourceKeyspaceName())
Expand Down Expand Up @@ -1184,7 +1184,7 @@ func (ts *trafficSwitcher) freezeTargetVReplication(ctx context.Context) error {
// re-invoked after a freeze, it will skip all the previous steps
err := ts.ForAllTargets(func(target *MigrationTarget) error {
ts.Logger().Infof("Marking target streams frozen for workflow %s db_name %s", ts.WorkflowName(), target.GetPrimary().DbName())
query := fmt.Sprintf("update _vt.vreplication set message = '%s' where db_name=%s and workflow=%s", Frozen,
query := fmt.Sprintf("update _vt.vreplication set message = %s where db_name=%s and workflow=%s", encodeString(Frozen),
encodeString(target.GetPrimary().DbName()), encodeString(ts.WorkflowName()))
_, err := ts.TabletManagerClient().VReplicationExec(ctx, target.GetPrimary().Tablet, query)
return err
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vtctl/workflow/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package workflow

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -627,9 +626,7 @@ func ReverseWorkflowName(workflow string) string {
// this public, but it doesn't belong in package workflow. Maybe package sqltypes,
// or maybe package sqlescape?
func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func getRenameFileName(tableName string) string {
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vttablet/endtoend/vstreamer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package endtoend

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -472,9 +471,7 @@ func expectLogs(ctx context.Context, t *testing.T, query string, eventCh chan []
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func validateSchemaInserted(client *framework.QueryClient, ddl string) bool {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vttablet/onlineddl/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1571,8 +1571,8 @@ func (e *Executor) ExecuteWithVReplication(ctx context.Context, onlineDDL *schem

{
// temporary hack. todo: this should be done when inserting any _vt.vreplication record across all workflow types
query := fmt.Sprintf("update _vt.vreplication set workflow_type = %d where workflow = '%s'",
binlogdatapb.VReplicationWorkflowType_OnlineDDL, v.workflow)
query := fmt.Sprintf("update _vt.vreplication set workflow_type = %d where workflow = %s",
binlogdatapb.VReplicationWorkflowType_OnlineDDL, sqltypes.EncodeStringSQL(v.workflow))
if _, err := e.vreplicationExec(ctx, tablet.Tablet, query); err != nil {
return vterrors.Wrapf(err, "VReplicationExec(%v, %s)", tablet.Tablet, query)
}
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vttablet/tabletmanager/vdiff/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package vdiff
import (
"context"
"fmt"
"strings"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -59,9 +58,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare
// Utility functions

func encodeString(in string) string {
var buf strings.Builder
sqltypes.NewVarChar(in).EncodeSQL(&buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func pkColsToGroupByParams(pkCols []int, collationEnv *collations.Environment) []*engine.GroupByParams {
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vttablet/tabletmanager/vreplication/insert_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ func NewInsertGenerator(state binlogdatapb.VReplicationWorkflowState, dbname str
func (ig *InsertGenerator) AddRow(workflow string, bls *binlogdatapb.BinlogSource, pos, cell, tabletTypes string,
workflowType binlogdatapb.VReplicationWorkflowType, workflowSubType binlogdatapb.VReplicationWorkflowSubType, deferSecondaryKeys bool, options string) {
if options == "" {
options = "'{}'"
options = "{}"
}
protoutil.SortBinlogSourceTables(bls)
fmt.Fprintf(ig.buf, "%s(%v, %v, %v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %v, %v)",
fmt.Fprintf(ig.buf, "%s(%v, %v, %v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %v, %v)",
ig.prefix,
encodeString(workflow),
encodeString(bls.String()),
Expand All @@ -66,12 +66,12 @@ func (ig *InsertGenerator) AddRow(workflow string, bls *binlogdatapb.BinlogSourc
encodeString(cell),
encodeString(tabletTypes),
ig.now,
ig.state,
encodeString(ig.state),
encodeString(ig.dbname),
workflowType,
workflowSubType,
deferSecondaryKeys,
options,
encodeString(options),
)
ig.prefix = ", "
}
Expand Down
6 changes: 2 additions & 4 deletions go/vt/vttablet/tabletmanager/vreplication/vreplicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ func (vr *vreplicator) setState(state binlogdatapb.VReplicationWorkflowState, me
})
}
vr.stats.State.Store(state.String())
query := fmt.Sprintf("update _vt.vreplication set state='%v', message=%v where id=%v", state, encodeString(binlogplayer.MessageTruncate(message)), vr.id)
query := fmt.Sprintf("update _vt.vreplication set state=%v, message=%v where id=%v", encodeString(state.String()), encodeString(binlogplayer.MessageTruncate(message)), vr.id)
// If we're batching a transaction, then include the state update
// in the current transaction batch.
if vr.dbClient.InTransaction && vr.dbClient.maxBatchSize > 0 {
Expand All @@ -528,9 +528,7 @@ func (vr *vreplicator) setState(state binlogdatapb.VReplicationWorkflowState, me
}

func encodeString(in string) string {
var buf strings.Builder
sqltypes.NewVarChar(in).EncodeSQL(&buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func (vr *vreplicator) getSettingFKCheck() error {
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vttablet/tabletserver/schema/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package schema

import (
"bytes"
"context"
"fmt"
"sync"
Expand Down Expand Up @@ -243,9 +242,7 @@ func (tr *Tracker) saveCurrentSchemaToDb(ctx context.Context, gtid, ddl string,
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

// MustReloadSchemaOnDDL returns true if the ddl is for the db which is part of the workflow and is not an online ddl artifact
Expand Down
4 changes: 1 addition & 3 deletions go/vt/vttablet/tabletserver/vstreamer/vstreamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,7 @@ type extColInfo struct {
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func (vs *vstreamer) processJournalEvent(vevents []*binlogdatapb.VEvent, plan *streamerPlan, rows mysql.Rows) ([]*binlogdatapb.VEvent, error) {
Expand Down
5 changes: 1 addition & 4 deletions go/vt/wrangler/keyspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package wrangler

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -125,7 +124,5 @@ func (wr *Wrangler) updateShardRecords(ctx context.Context, keyspace string, sha
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

0 comments on commit 359128d

Please sign in to comment.