Skip to content

Commit

Permalink
Faster Prepared Statement Execution by Using Raw SQL for Caching (#17777
Browse files Browse the repository at this point in the history
)

Signed-off-by: Harshit Gangal <harshit@planetscale.com>
Signed-off-by: Andres Taylor <andres@planetscale.com>
Co-authored-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
harshit-gangal and systay authored Mar 3, 2025
1 parent 1582d5b commit 0af627a
Show file tree
Hide file tree
Showing 173 changed files with 3,145 additions and 3,453 deletions.
11 changes: 9 additions & 2 deletions go/cmd/vtgateclienttest/services/callerid.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,18 @@ func (c *callerIDClient) checkCallerID(ctx context.Context, received string) (bo
return true, fmt.Errorf("SUCCESS: callerid matches")
}

func (c *callerIDClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *callerIDClient) Execute(
ctx context.Context,
mysqlCtx vtgateservice.MySQLConnection,
session *vtgatepb.Session,
sql string,
bindVariables map[string]*querypb.BindVariable,
prepared bool,
) (*vtgatepb.Session, *sqltypes.Result, error) {
if ok, err := c.checkCallerID(ctx, sql); ok {
return session, nil, err
}
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables, prepared)
}

func (c *callerIDClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand Down
18 changes: 12 additions & 6 deletions go/cmd/vtgateclienttest/services/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ func echoQueryResult(vals map[string]any) *sqltypes.Result {
return qr
}

func (c *echoClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *echoClient) Execute(
ctx context.Context,
mysqlCtx vtgateservice.MySQLConnection,
session *vtgatepb.Session,
sql string,
bindVariables map[string]*querypb.BindVariable,
prepared bool,
) (*vtgatepb.Session, *sqltypes.Result, error) {
if strings.HasPrefix(sql, EchoPrefix) {
return session, echoQueryResult(map[string]any{
"callerId": callerid.EffectiveCallerIDFromContext(ctx),
Expand All @@ -107,7 +114,7 @@ func (c *echoClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLCo
"session": session,
}), nil
}
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables, prepared)
}

func (c *echoClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
Expand Down Expand Up @@ -173,14 +180,13 @@ func (c *echoClient) VStream(ctx context.Context, tabletType topodatapb.TabletTy
return c.fallbackClient.VStream(ctx, tabletType, vgtid, filter, flags, callback)
}

func (c *echoClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
func (c *echoClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) {
if strings.HasPrefix(sql, EchoPrefix) {
return session, echoQueryResult(map[string]any{
"callerId": callerid.EffectiveCallerIDFromContext(ctx),
"query": sql,
"bindVars": bindVariables,
"session": session,
}).Fields, nil
}).Fields, 0, nil
}
return c.fallbackClient.Prepare(ctx, session, sql, bindVariables)
return c.fallbackClient.Prepare(ctx, session, sql)
}
19 changes: 13 additions & 6 deletions go/cmd/vtgateclienttest/services/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,21 @@ func trimmedRequestToError(received string) error {
}
}

func (c *errorClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *errorClient) Execute(
ctx context.Context,
mysqlCtx vtgateservice.MySQLConnection,
session *vtgatepb.Session,
sql string,
bindVariables map[string]*querypb.BindVariable,
prepared bool,
) (*vtgatepb.Session, *sqltypes.Result, error) {
if err := requestToPartialError(sql, session); err != nil {
return session, nil, err
}
if err := requestToError(sql); err != nil {
return session, nil, err
}
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables, prepared)
}

func (c *errorClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand All @@ -139,14 +146,14 @@ func (c *errorClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.
return c.fallbackClient.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback)
}

func (c *errorClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
func (c *errorClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) {
if err := requestToPartialError(sql, session); err != nil {
return session, nil, err
return session, nil, 0, err
}
if err := requestToError(sql); err != nil {
return session, nil, err
return session, nil, 0, err
}
return c.fallbackClient.Prepare(ctx, session, sql, bindVariables)
return c.fallbackClient.Prepare(ctx, session, sql)
}

func (c *errorClient) CloseSession(ctx context.Context, session *vtgatepb.Session) error {
Expand Down
15 changes: 11 additions & 4 deletions go/cmd/vtgateclienttest/services/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ func newFallbackClient(fallback vtgateservice.VTGateService) fallbackClient {
return fallbackClient{fallback: fallback}
}

func (c fallbackClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
return c.fallback.Execute(ctx, mysqlCtx, session, sql, bindVariables)
func (c fallbackClient) Execute(
ctx context.Context,
mysqlCtx vtgateservice.MySQLConnection,
session *vtgatepb.Session,
sql string,
bindVariables map[string]*querypb.BindVariable,
prepared bool,
) (*vtgatepb.Session, *sqltypes.Result, error) {
return c.fallback.Execute(ctx, mysqlCtx, session, sql, bindVariables, prepared)
}

func (c fallbackClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand All @@ -52,8 +59,8 @@ func (c fallbackClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservic
return c.fallback.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback)
}

func (c fallbackClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
return c.fallback.Prepare(ctx, session, sql, bindVariables)
func (c fallbackClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) {
return c.fallback.Prepare(ctx, session, sql)
}

func (c fallbackClient) CloseSession(ctx context.Context, session *vtgatepb.Session) error {
Expand Down
13 changes: 10 additions & 3 deletions go/cmd/vtgateclienttest/services/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ func newTerminalClient() *terminalClient {
return &terminalClient{}
}

func (c *terminalClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *terminalClient) Execute(
ctx context.Context,
mysqlCtx vtgateservice.MySQLConnection,
session *vtgatepb.Session,
sql string,
bindVariables map[string]*querypb.BindVariable,
prepared bool,
) (*vtgatepb.Session, *sqltypes.Result, error) {
if sql == "quit://" {
log.Fatal("Received quit:// query. Going down.")
}
Expand All @@ -63,8 +70,8 @@ func (c *terminalClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservi
return session, errTerminal
}

func (c *terminalClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
return session, nil, errTerminal
func (c *terminalClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) {
return session, nil, 0, errTerminal
}

func (c *terminalClient) CloseSession(ctx context.Context, session *vtgatepb.Session) error {
Expand Down
65 changes: 16 additions & 49 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
)

Expand Down Expand Up @@ -1122,7 +1121,7 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool
return c.writeErrorPacketFromErrorAndLog(err)
}

fieldSent := false
receivedResult := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false
prepare := c.PrepareData[stmtID]
Expand All @@ -1132,8 +1131,8 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool
return io.EOF
}

if !fieldSent {
fieldSent = true
if !receivedResult {
receivedResult = true

if len(qr.Fields) == 0 {
sendFinished = true
Expand All @@ -1157,7 +1156,7 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool
})

// If no field was sent, we expect an error.
if !fieldSent {
if !receivedResult {
// This is just a failsafe. Should never happen.
if err == nil || err == io.EOF {
err = sqlerror.NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
Expand Down Expand Up @@ -1200,10 +1199,8 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
query := c.parseComPrepare(data)
c.recycleReadPacket()

var queries []string
if c.Capabilities&CapabilityClientMultiStatements != 0 {
var err error
queries, err = handler.Env().Parser().SplitStatementToPieces(query)
queries, err := handler.Env().Parser().SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
return c.writeErrorPacketFromErrorAndLog(err)
Expand All @@ -1212,56 +1209,26 @@ func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
log.Errorf("Conn %v: can not prepare multiple statements", c)
return c.writeErrorPacketFromErrorAndLog(err)
}
} else {
queries = []string{query}
query = queries[0]
}

fld, paramsCount, err := handler.ComPrepare(c, query)
if err != nil {
return c.writeErrorPacketFromErrorAndLog(err)
}

// Populate PrepareData
c.StatementID++
prepare := &PrepareData{
StatementID: c.StatementID,
PrepareStmt: queries[0],
}

statement, err := handler.Env().Parser().ParseStrictDDL(query)
if err != nil {
log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err)
if !c.writeErrorPacketFromErrorAndLog(err) {
return false
}
PrepareStmt: query,
ParamsCount: paramsCount,
ParamsType: make([]int32, paramsCount),
BindVars: make(map[string]*querypb.BindVariable, paramsCount),
}

paramsCount := uint16(0)
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
switch node := node.(type) {
case *sqlparser.Argument:
if strings.HasPrefix(node.Name, "v") {
paramsCount++
}
}
return true, nil
}, statement)

if paramsCount > 0 {
prepare.ParamsCount = paramsCount
prepare.ParamsType = make([]int32, paramsCount)
prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount)
}

bindVars := make(map[string]*querypb.BindVariable, paramsCount)
for i := range uint16(paramsCount) {
parameterID := fmt.Sprintf("v%d", i+1)
bindVars[parameterID] = &querypb.BindVariable{}
}

c.PrepareData[c.StatementID] = prepare

fld, err := handler.ComPrepare(c, queries[0], bindVars)
if err != nil {
return c.writeErrorPacketFromErrorAndLog(err)
}

if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil {
if err := c.writePrepare(fld, prepare); err != nil {
log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err)
return false
}
Expand Down
Loading

0 comments on commit 0af627a

Please sign in to comment.