Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stash the connection attributes on the conn struct #152

Merged
merged 2 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ type Conn struct {
// It is set during the initial handshake.
UserData Getter

// ConnectionAttributes stores attributes set in the connection phase when
// attributes from the client are sent. This is arbitrary key/value pairs
// sent by the client.
ConnectionAttributes ConnectionAttributesMap

bufferedReader *bufio.Reader
flushTimer *time.Timer
header [packetHeaderSize]byte
Expand Down
4 changes: 4 additions & 0 deletions go/mysql/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ const (
// implemented authentication methods.
type AuthMethodDescription string

// Map of client key/value pairs sent by the client during
// the connection phase
type ConnectionAttributesMap map[string]string

// Supported auth forms.
const (
// MysqlNativePassword uses a salt and transmits a hash on the wire.
Expand Down
46 changes: 26 additions & 20 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}
return
}
user, clientAuthMethod, clientAuthResponse, err := l.parseClientHandshakePacket(c, true, response)
user, clientAuthMethod, clientAuthResponse, clientAttributes, err := l.parseClientHandshakePacket(c, true, response)
if err != nil {
log.Errorf("Cannot parse client handshake response from %s: %v", c, err)
return
}
c.ConnectionAttributes = clientAttributes

c.recycleReadPacket()

Expand All @@ -371,11 +372,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}

// Returns copies of the data, so we can recycle the buffer.
user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response)
user, clientAuthMethod, clientAuthResponse, clientAttributes, err = l.parseClientHandshakePacket(c, false, response)
if err != nil {
log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err)
return
}
c.ConnectionAttributes = clientAttributes
c.recycleReadPacket()

if con, ok := c.conn.(*tls.Conn); ok {
Expand Down Expand Up @@ -636,18 +638,18 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en
}

// parseClientHandshakePacket parses the handshake sent by the client.
// Returns the username, auth method, auth data, error.
// Returns the username, auth method, auth data, connection attributes, error.
// The original data is not pointed at, and can be freed.
func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) {
func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, ConnectionAttributesMap, error) {
pos := 0

// Client flags, 4 bytes.
clientFlags, pos, ok := readUint32(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags")
}
if clientFlags&CapabilityClientProtocol41 == 0 {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1")
}

// Remember a subset of the capabilities, so we can use them
Expand All @@ -666,13 +668,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
// See doc.go for more information.
_, pos, ok = readUint32(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize")
}

// Character set. Need to handle it.
characterSet, pos, ok := readByte(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet")
}
c.CharacterSet = collations.ID(characterSet)

Expand All @@ -686,13 +688,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
c.conn = conn
c.bufferedReader.Reset(conn)
c.Capabilities |= CapabilityClientSSL
return "", "", nil, nil
return "", "", nil, nil, nil
}

// username
username, pos, ok := readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username")
}

// auth-response can have three forms.
Expand All @@ -701,29 +703,29 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
var l uint64
l, pos, ok = readLenEncInt(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length")
}
authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}

} else if clientFlags&CapabilityClientSecureConnection != 0 {
var l byte
l, pos, ok = readByte(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length")
}

authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}
} else {
a := ""
a, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}
authResponse = []byte(a)
}
Expand All @@ -733,7 +735,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
dbname := ""
dbname, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname")
}
c.schemaName = dbname
}
Expand All @@ -744,7 +746,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
var authMethodStr string
authMethodStr, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod")
}
// The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password
if authMethodStr != "" {
Expand All @@ -753,16 +755,20 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
}

// Decode connection attributes send by the client
var clientAttributes map[string]string
if clientFlags&CapabilityClientConnAttr != 0 {
if _, _, err := parseConnAttrs(data, pos); err != nil {
ca, _, err := parseConnAttrs(data, pos)
if err != nil {
log.Warningf("Decode connection attributes send by the client: %v", err)
}

clientAttributes = ca
}

return username, AuthMethodDescription(authMethod), authResponse, nil
return username, AuthMethodDescription(authMethod), authResponse, clientAttributes, nil
}

func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
func parseConnAttrs(data []byte, pos int) (ConnectionAttributesMap, int, error) {
var attrLen uint64

attrLen, pos, ok := readLenEncInt(data, pos)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgateproxy/mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (ph *proxyHandler) session(c *mysql.Conn) *vtgateconn.VTGateSession {
}

var err error
session, err = ph.proxy.NewSession(options)
session, err = ph.proxy.NewSession(options, c.ConnectionAttributes)
if err != nil {
log.Errorf("error creating new session for %s: %v", c.GetRawConn().RemoteAddr().String(), err)
}
Expand Down
10 changes: 7 additions & 3 deletions go/vt/vtgateproxy/vtgateproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package vtgateproxy
import (
"context"
"flag"
"fmt"
"io"
"time"

Expand Down Expand Up @@ -68,11 +69,16 @@ func (proxy *VTGateProxy) connect(ctx context.Context) error {
return nil
}

func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions) (*vtgateconn.VTGateSession, error) {
func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions, connectionAttributes map[string]string) (*vtgateconn.VTGateSession, error) {
if proxy.conn == nil {
return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected")
}

target, ok := connectionAttributes["target"]
if ok {
fmt.Printf("Creating new session from upstream provided target string: %v\n", target)
}

// XXX/demmer handle schemaName?
return proxy.conn.Session("", options), nil
}
Expand All @@ -95,8 +101,6 @@ func (proxy *VTGateProxy) Prepare(ctx context.Context, session *vtgateconn.VTGat
}

func (proxy *VTGateProxy) Execute(ctx context.Context, session *vtgateconn.VTGateSession, sql string, bindVariables map[string]*querypb.BindVariable) (qr *sqltypes.Result, err error) {
log.Infof("Execute %s", sql)

if proxy.conn == nil {
return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected")
}
Expand Down
Loading