Skip to content

Commit

Permalink
Stash the connection attributes on the conn struct (#152)
Browse files Browse the repository at this point in the history
* Stash the connection attributes on the conn struct

* Clean up code style around upstream server changes
  • Loading branch information
jscheinblum authored and dedelala committed Jan 8, 2025
1 parent 5ddcc0e commit 6e74221
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 24 deletions.
5 changes: 5 additions & 0 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ type Conn struct {
// It is set during the initial handshake.
UserData Getter

// ConnectionAttributes stores attributes set in the connection phase when
// attributes from the client are sent. This is arbitrary key/value pairs
// sent by the client.
ConnectionAttributes ConnectionAttributesMap

bufferedReader *bufio.Reader
flushTimer *time.Timer
flushDelay time.Duration
Expand Down
4 changes: 4 additions & 0 deletions go/mysql/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ const (
// implemented authentication methods.
type AuthMethodDescription string

// Map of client key/value pairs sent by the client during
// the connection phase
type ConnectionAttributesMap map[string]string

// Supported auth forms.
const (
// MysqlNativePassword uses a salt and transmits a hash on the wire.
Expand Down
46 changes: 26 additions & 20 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}
return
}
user, clientAuthMethod, clientAuthResponse, err := l.parseClientHandshakePacket(c, true, response)
user, clientAuthMethod, clientAuthResponse, clientAttributes, err := l.parseClientHandshakePacket(c, true, response)
if err != nil {
log.Errorf("Cannot parse client handshake response from %s: %v", c, err)
return
}
c.ConnectionAttributes = clientAttributes

c.recycleReadPacket()

Expand All @@ -426,11 +427,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}

// Returns copies of the data, so we can recycle the buffer.
user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response)
user, clientAuthMethod, clientAuthResponse, clientAttributes, err = l.parseClientHandshakePacket(c, false, response)
if err != nil {
log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err)
return
}
c.ConnectionAttributes = clientAttributes
c.recycleReadPacket()

if con, ok := c.conn.(*tls.Conn); ok {
Expand Down Expand Up @@ -688,18 +690,18 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, ch
}

// parseClientHandshakePacket parses the handshake sent by the client.
// Returns the username, auth method, auth data, error.
// Returns the username, auth method, auth data, connection attributes, error.
// The original data is not pointed at, and can be freed.
func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) {
func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, ConnectionAttributesMap, error) {
pos := 0

// Client flags, 4 bytes.
clientFlags, pos, ok := readUint32(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags")
}
if clientFlags&CapabilityClientProtocol41 == 0 {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1")
}

// Remember a subset of the capabilities, so we can use them
Expand All @@ -718,13 +720,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
// See doc.go for more information.
_, pos, ok = readUint32(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize")
}

// Character set. Need to handle it.
characterSet, pos, ok := readByte(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet")
}
c.CharacterSet = collations.ID(characterSet)

Expand All @@ -738,13 +740,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
c.conn = conn
c.bufferedReader.Reset(conn)
c.Capabilities |= CapabilityClientSSL
return "", "", nil, nil
return "", "", nil, nil, nil
}

// username
username, pos, ok := readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username")
}

// auth-response can have three forms.
Expand All @@ -753,29 +755,29 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
var l uint64
l, pos, ok = readLenEncInt(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length")
}
authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}

} else if clientFlags&CapabilityClientSecureConnection != 0 {
var l byte
l, pos, ok = readByte(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length")
}

authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}
} else {
a := ""
a, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}
authResponse = []byte(a)
}
Expand All @@ -785,7 +787,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
dbname := ""
dbname, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname")
}
c.schemaName = dbname
}
Expand All @@ -796,7 +798,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
var authMethodStr string
authMethodStr, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod")
}
// The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password
if authMethodStr != "" {
Expand All @@ -805,16 +807,20 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
}

// Decode connection attributes send by the client
var clientAttributes map[string]string
if clientFlags&CapabilityClientConnAttr != 0 {
if _, _, err := parseConnAttrs(data, pos); err != nil {
ca, _, err := parseConnAttrs(data, pos)
if err != nil {
log.Warningf("Decode connection attributes send by the client: %v", err)
}

clientAttributes = ca
}

return username, AuthMethodDescription(authMethod), authResponse, nil
return username, AuthMethodDescription(authMethod), authResponse, clientAttributes, nil
}

func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
func parseConnAttrs(data []byte, pos int) (ConnectionAttributesMap, int, error) {
var attrLen uint64

attrLen, pos, ok := readLenEncInt(data, pos)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgateproxy/mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (ph *proxyHandler) session(c *mysql.Conn) *vtgateconn.VTGateSession {
}

var err error
session, err = ph.proxy.NewSession(options)
session, err = ph.proxy.NewSession(options, c.ConnectionAttributes)
if err != nil {
log.Errorf("error creating new session for %s: %v", c.GetRawConn().RemoteAddr().String(), err)
}
Expand Down
10 changes: 7 additions & 3 deletions go/vt/vtgateproxy/vtgateproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package vtgateproxy
import (
"context"
"flag"
"fmt"
"io"
"time"

Expand Down Expand Up @@ -68,11 +69,16 @@ func (proxy *VTGateProxy) connect(ctx context.Context) error {
return nil
}

func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions) (*vtgateconn.VTGateSession, error) {
func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions, connectionAttributes map[string]string) (*vtgateconn.VTGateSession, error) {
if proxy.conn == nil {
return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected")
}

target, ok := connectionAttributes["target"]
if ok {
fmt.Printf("Creating new session from upstream provided target string: %v\n", target)
}

// XXX/demmer handle schemaName?
return proxy.conn.Session("", options), nil
}
Expand All @@ -95,8 +101,6 @@ func (proxy *VTGateProxy) Prepare(ctx context.Context, session *vtgateconn.VTGat
}

func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable) (qr *sqltypes.Result, err error) {
log.Infof("Execute %s", sql)

if proxy.conn == nil {
return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected")
}
Expand Down

0 comments on commit 6e74221

Please sign in to comment.