Skip to content

Commit

Permalink
basic skeleton of a working proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
demmer committed Nov 2, 2023
1 parent 0e00f2e commit 55cdd40
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 106 deletions.
200 changes: 101 additions & 99 deletions go/vt/vtgateproxy/mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/vtgateconn"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
Expand All @@ -43,10 +44,7 @@ import (
"vitess.io/vitess/go/vt/vttls"

querypb "vitess.io/vitess/go/vt/proto/query"
vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"

"github.com/google/uuid"
)

var (
Expand Down Expand Up @@ -87,47 +85,35 @@ type proxyHandler struct {
mysql.UnimplementedHandler
mu sync.Mutex

proxy *VTGateProxy
connections map[*mysql.Conn]bool
proxy *VTGateProxy
}

func newProxyHandler(proxy *VTGateProxy) *proxyHandler {
return &proxyHandler{
proxy: proxy,
connections: make(map[*mysql.Conn]bool),
proxy: proxy,
}
}

func (vh *proxyHandler) NewConnection(c *mysql.Conn) {
vh.mu.Lock()
defer vh.mu.Unlock()
vh.connections[c] = true
}

func (vh *proxyHandler) numConnections() int {
vh.mu.Lock()
defer vh.mu.Unlock()
return len(vh.connections)
func (ph *proxyHandler) NewConnection(c *mysql.Conn) {
}

func (vh *proxyHandler) ComResetConnection(c *mysql.Conn) {
func (ph *proxyHandler) ComResetConnection(c *mysql.Conn) {
ctx := context.Background()
session := vh.session(c)
if session.InTransaction {
session := ph.session(c)
if session.SessionPb().InTransaction {
defer atomic.AddInt32(&busyConnections, -1)
}
err := vh.proxy.CloseSession(ctx, session)
err := ph.proxy.CloseSession(ctx, session)
if err != nil {
log.Errorf("Error happened in transaction rollback: %v", err)
}
}

func (vh *proxyHandler) ConnectionClosed(c *mysql.Conn) {
func (ph *proxyHandler) ConnectionClosed(c *mysql.Conn) {
// Rollback if there is an ongoing transaction. Ignore error.
defer func() {
vh.mu.Lock()
defer vh.mu.Unlock()
delete(vh.connections, c)
ph.mu.Lock()
defer ph.mu.Unlock()
}()

var ctx context.Context
Expand All @@ -138,11 +124,11 @@ func (vh *proxyHandler) ConnectionClosed(c *mysql.Conn) {
} else {
ctx = context.Background()
}
session := vh.session(c)
if session.InTransaction {
session := ph.session(c)
if session.SessionPb().InTransaction {
defer atomic.AddInt32(&busyConnections, -1)
}
_ = vh.proxy.CloseSession(ctx, session)
_ = ph.proxy.CloseSession(ctx, session)
}

// Regexp to extract parent span id over the sql query
Expand Down Expand Up @@ -179,7 +165,7 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co
return startSpanTestable(ctx, query, label, trace.NewSpan, trace.NewFromString)
}

func (vh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
func (ph *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
ctx := context.Background()
var cancel context.CancelFunc
if *mysqlQueryTimeout != 0 {
Expand Down Expand Up @@ -207,21 +193,26 @@ func (vh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sql
"VTGate MySQL Connector" /* subcomponent: part of the client */)
ctx = callerid.NewContext(ctx, ef, im)

session := vh.session(c)
if !session.InTransaction {
session := ph.session(c)
if session != nil && !session.SessionPb().InTransaction {
atomic.AddInt32(&busyConnections, 1)
}
defer func() {
if !session.InTransaction {
if session == nil || !session.SessionPb().InTransaction {
atomic.AddInt32(&busyConnections, -1)
}
}()

if session.Options.Workload == querypb.ExecuteOptions_OLAP {
err := vh.proxy.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback)
return mysql.NewSQLErrorFromError(err)
}
session, result, err := vh.proxy.Execute(ctx, session, query, make(map[string]*querypb.BindVariable))
/*
XXX/demmer figure out OLAP
if session.Options.Workload == querypb.ExecuteOptions_OLAP {
err := ph.proxy.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback)
return mysql.NewSQLErrorFromError(err)
}
*/

result, err := ph.proxy.Execute(ctx, session, query, make(map[string]*querypb.BindVariable))

if err := mysql.NewSQLErrorFromError(err); err != nil {
return err
Expand All @@ -230,21 +221,21 @@ func (vh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sql
return callback(result)
}

func fillInTxStatusFlags(c *mysql.Conn, session *vtgatepb.Session) {
if session.InTransaction {
func fillInTxStatusFlags(c *mysql.Conn, session *vtgateconn.VTGateSession) {
if session.SessionPb().InTransaction {
c.StatusFlags |= mysql.ServerStatusInTrans
} else {
c.StatusFlags &= mysql.NoServerStatusInTrans
}
if session.Autocommit {
if session.SessionPb().Autocommit {
c.StatusFlags |= mysql.ServerStatusAutocommit
} else {
c.StatusFlags &= mysql.NoServerStatusAutocommit
}
}

// ComPrepare is the handler for command prepare.
func (vh *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
func (ph *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
var ctx context.Context
var cancel context.CancelFunc
if *mysqlQueryTimeout != 0 {
Expand All @@ -268,25 +259,25 @@ func (vh *proxyHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[str
"VTGateProxy MySQL Connector" /* subcomponent: part of the client */)
ctx = callerid.NewContext(ctx, ef, im)

session := vh.session(c)
if !session.InTransaction {
session := ph.session(c)
if !session.SessionPb().InTransaction {
atomic.AddInt32(&busyConnections, 1)
}
defer func() {
if !session.InTransaction {
if !session.SessionPb().InTransaction {
atomic.AddInt32(&busyConnections, -1)
}
}()

session, fld, err := vh.proxy.Prepare(ctx, session, query, bindVars)
session, fld, err := ph.proxy.Prepare(ctx, session, query, bindVars)
err = mysql.NewSQLErrorFromError(err)
if err != nil {
return nil, err
}
return fld, nil
}

func (vh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
func (ph *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
var ctx context.Context
var cancel context.CancelFunc
if *mysqlQueryTimeout != 0 {
Expand All @@ -310,21 +301,25 @@ func (vh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData
"VTGateProxy MySQL Connector" /* subcomponent: part of the client */)
ctx = callerid.NewContext(ctx, ef, im)

session := vh.session(c)
if !session.InTransaction {
session := ph.session(c)
if !session.SessionPb().InTransaction {
atomic.AddInt32(&busyConnections, 1)
}
defer func() {
if !session.InTransaction {
if !session.SessionPb().InTransaction {
atomic.AddInt32(&busyConnections, -1)
}
}()

if session.Options.Workload == querypb.ExecuteOptions_OLAP {
err := vh.proxy.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback)
return mysql.NewSQLErrorFromError(err)
}
_, qr, err := vh.proxy.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars)
/*
XXX/demmer figure out OLAP
if session.Options.Workload == querypb.ExecuteOptions_OLAP {
err := ph.proxy.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback)
return mysql.NewSQLErrorFromError(err)
}
*/

qr, err := ph.proxy.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars)
if err != nil {
err = mysql.NewSQLErrorFromError(err)
return err
Expand All @@ -334,43 +329,45 @@ func (vh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData
return callback(qr)
}

func (vh *proxyHandler) WarningCount(c *mysql.Conn) uint16 {
return uint16(len(vh.session(c).GetWarnings()))
func (ph *proxyHandler) WarningCount(c *mysql.Conn) uint16 {
return uint16(len(ph.session(c).SessionPb().GetWarnings()))
}

// ComBinlogDumpGTID is part of the mysql.Handler interface.
func (vh *proxyHandler) ComBinlogDumpGTID(c *mysql.Conn, gtidSet mysql.GTIDSet) error {
func (ph *proxyHandler) ComBinlogDumpGTID(c *mysql.Conn, gtidSet mysql.GTIDSet) error {
return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "ComBinlogDumpGTID")
}

func (vh *proxyHandler) session(c *mysql.Conn) *vtgatepb.Session {
session, _ := c.ClientData.(*vtgatepb.Session)
func (ph *proxyHandler) session(c *mysql.Conn) *vtgateconn.VTGateSession {
session, _ := c.ClientData.(*vtgateconn.VTGateSession)
if session == nil {
u, _ := uuid.NewUUID()
session = &vtgatepb.Session{
Options: &querypb.ExecuteOptions{
IncludedFields: querypb.ExecuteOptions_ALL,
Workload: querypb.ExecuteOptions_Workload(mysqlDefaultWorkload),

// The collation field of ExecuteOption is set right before an execution.
},
Autocommit: true,
DDLStrategy: *defaultDDLStrategy,
SessionUUID: u.String(),
EnableSystemSettings: *sysVarSetEnabled,
options := &querypb.ExecuteOptions{
IncludedFields: querypb.ExecuteOptions_ALL,
Workload: querypb.ExecuteOptions_Workload(mysqlDefaultWorkload),
}

if c.Capabilities&mysql.CapabilityClientFoundRows != 0 {
session.Options.ClientFoundRows = true
options.ClientFoundRows = true
}

var err error
session, err = ph.proxy.NewSession(options)
if err != nil {
log.Errorf("error creating new session for %s: %v", c.GetRawConn().RemoteAddr().String(), err)
}

if session != nil {
c.ClientData = session
}
c.ClientData = session
}

return session
}

var mysqlListener *mysql.Listener
var mysqlUnixListener *mysql.Listener
var sigChan chan os.Signal
var vtgateHandle *proxyHandler
var proxyHandle *proxyHandler

// initTLSConfig inits tls config for the given mysql listener
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error {
Expand Down Expand Up @@ -426,10 +423,10 @@ func initMySQLProtocol() {

// Create a Listener.
var err error
vtgateHandle = newProxyHandler(vtGateProxy)
proxyHandle = newProxyHandler(vtGateProxy)
if *mysqlServerPort >= 0 {
log.Infof("Mysql Server listening on Port %d", *mysqlServerPort)
mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, vtgateHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol)
mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol)
if err != nil {
log.Exitf("mysql.NewListener failed: %v", err)
}
Expand Down Expand Up @@ -458,7 +455,7 @@ func initMySQLProtocol() {
// Let's create this unix socket with permissions to all users. In this way,
// clients can connect to vtgate mysql server without being vtgate user
oldMask := syscall.Umask(000)
mysqlUnixListener, err = newMysqlUnixSocket(*mysqlServerSocketPath, authServer, vtgateHandle)
mysqlUnixListener, err = newMysqlUnixSocket(*mysqlServerSocketPath, authServer, proxyHandle)
_ = syscall.Umask(oldMask)
if err != nil {
log.Exitf("mysql.NewListener failed: %v", err)
Expand Down Expand Up @@ -531,30 +528,35 @@ func shutdownMysqlProtocolAndDrain() {
func rollbackAtShutdown() {
defer log.Flush()

// Close all open connections. If they're waiting for reads, this will cause
// them to error out, which will automatically rollback open transactions.
func() {
if vtgateHandle != nil {
vtgateHandle.mu.Lock()
defer vtgateHandle.mu.Unlock()
for c := range vtgateHandle.connections {
if c != nil {
log.Infof("Rolling back transactions associated with connection ID: %v", c.ConnectionID)
c.Close()
// XXX/demmer figure out numConnections and rollback
/*
// Close all open connections. If they're waiting for reads, this will cause
// them to error out, which will automatically rollback open transactions.
func() {
if proxyHandle != nil {
proxyHandle.mu.Lock()
defer proxyHandle.mu.Unlock()
for c := range proxyHandle.connections {
if c != nil {
log.Infof("Rolling back transactions associated with connection ID: %v", c.ConnectionID)
c.Close()
}
}
}
}
}()

// If vtgate is instead busy executing a query, the number of open conns
// will be non-zero. Give another second for those queries to finish.
for i := 0; i < 100; i++ {
if vtgateHandle.numConnections() == 0 {
log.Infof("All connections have been rolled back.")
return
}
time.Sleep(10 * time.Millisecond)
}
}()
// If vtgate is instead busy executing a query, the number of open conns
// will be non-zero. Give another second for those queries to finish.
for i := 0; i < 100; i++ {
if proxyHandle.numConnections() == 0 {
log.Infof("All connections have been rolled back.")
return
}
time.Sleep(10 * time.Millisecond)
}
*/
log.Errorf("All connections did not go idle. Shutting down anyway.")
}

Expand Down
Loading

0 comments on commit 55cdd40

Please sign in to comment.