Skip to content

Commit

Permalink
backend, net: add more error sources (#407)
Browse files Browse the repository at this point in the history
Signed-off-by: xhe <xw897002528@gmail.com>
Co-authored-by: xhe <xw897002528@gmail.com>
  • Loading branch information
djshow832 and xhebox authored Dec 1, 2023
1 parent 8712ca0 commit f22f82b
Show file tree
Hide file tree
Showing 22 changed files with 332 additions and 235 deletions.
5 changes: 0 additions & 5 deletions pkg/manager/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@ import (
"time"

glist "github.com/bahlo/generic-list-go"
"github.com/pingcap/tiproxy/lib/util/errors"
)

var (
ErrNoInstanceToSelect = errors.New("no instances to route")
)

// ConnEventReceiver receives connection events.
Expand Down
73 changes: 38 additions & 35 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ import (
"go.uber.org/zap"
)

var (
ErrCapabilityNegotiation = errors.New("capability negotiation failed")
)

const unknownAuthPlugin = "auth_unknown_plugin"
const requiredFrontendCaps = pnet.ClientProtocol41
const defRequiredBackendCaps = pnet.ClientDeprecateEOF
Expand Down Expand Up @@ -76,10 +72,10 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)
return errors.Wrapf(ErrBackendCap, "require %s from backend", requiredBackendCaps^commonCaps)
}
if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) {
return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg)
return ErrBackendNoTLS
}
return nil
}
Expand All @@ -106,7 +102,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
frontendCapability := pnet.Capability(binary.LittleEndian.Uint32(pkt))
if isSSL {
if _, err = clientIO.ServerTLSHandshake(frontendTLSConfig); err != nil {
return pnet.WrapUserError(err, err.Error())
return errors.Wrap(ErrClientHandshake, err)
}
pkt, _, err = clientIO.ReadSSLRequestOrHandshakeResp()
if err != nil {
Expand All @@ -125,7 +121,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil {
return writeErr
}
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)
return errors.Wrapf(ErrClientCap, "require %s from frontend", requiredFrontendCaps&^commonCaps)
}
commonCaps := frontendCapability & proxyCapability
if frontendCapability^commonCaps != 0 {
Expand All @@ -147,10 +143,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
if errors.As(err, &warning) {
logger.Warn("parse handshake response encounters error", zap.Error(err))
} else if err != nil {
return pnet.WrapUserError(err, parsePktErrMsg)
return err
}
if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil {
return pnet.WrapUserError(err, err.Error())
return errors.Wrap(ErrProxyErr, err)
}
auth.user = clientResp.User
auth.dbname = clientResp.DB
Expand All @@ -163,29 +159,28 @@ RECONNECT:
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(cctx, auth, clientResp)
if err != nil {
return pnet.WrapUserError(err, connectErrMsg)
return err
}
backendIO.ResetSequence()

// write proxy header
if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil {
return pnet.WrapUserError(err, handshakeErrMsg)
return err
}

// read backend initial handshake
serverPkt, backendCapability, err := auth.readInitialHandshake(backendIO)
if err != nil {
if IsMySQLError(err) {
if pnet.IsMySQLError(err) {
if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil {
err = writeErr
return writeErr
}
return err
}
return pnet.WrapUserError(err, handshakeErrMsg)
return err
}

if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return pnet.WrapUserError(err, capabilityErrMsg)
return err
}

if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
Expand All @@ -207,7 +202,7 @@ RECONNECT:
// Copy the auth data so that the backend can set correct `using password` in the error message.
unknownAuthPlugin, clientResp.AuthData, 0,
); err != nil {
return pnet.WrapUserError(err, handshakeErrMsg)
return err
}

// forward other packets
Expand All @@ -220,16 +215,18 @@ loop:
// tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence
// tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence
if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) {
return pnet.WrapUserError(err, checkPPV2ErrMsg)
return errors.Wrap(ErrBackendPPV2, err)
}
return err
}
var packetErr error
var packetErr *mysql.MyError
if serverPkt[0] == pnet.ErrHeader.Byte() {
packetErr = pnet.ParseErrorPacket(serverPkt)
if handshakeHandler.HandleHandshakeErr(cctx, packetErr.(*mysql.MyError)) {
logger.Warn("handle handshake error, start reconnect", zap.Error(err))
backendIO.Close()
if handshakeHandler.HandleHandshakeErr(cctx, packetErr) {
logger.Warn("handle handshake error, start reconnect", zap.Error(packetErr))
if closeErr := backendIO.Close(); closeErr != nil {
logger.Warn("close backend error", zap.Error(closeErr))
}
goto RECONNECT
}
}
Expand All @@ -238,17 +235,17 @@ loop:
return err
}
if packetErr != nil {
return packetErr
return errors.Wrap(ErrClientAuthFail, packetErr)
}

pktIdx++
switch serverPkt[0] {
case pnet.OKHeader.Byte():
if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil {
return err
return errors.Wrap(ErrClientHandshake, err)
}
if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
return err
return errors.Wrap(ErrBackendHandshake, err)
}
return nil
default: // mysql.AuthSwitchRequest, ShaCommand
Expand Down Expand Up @@ -276,7 +273,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) {

func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error {
if len(sessionToken) == 0 {
return errors.New("session token is empty")
return errors.Wrapf(ErrBackendHandshake, "session token is empty")
}

// write proxy header
Expand All @@ -301,17 +298,20 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
}

if err = auth.handleSecondAuthResult(backendIO); err == nil {
return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel)
if err = setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
return errors.Wrap(ErrBackendHandshake, err)
}
}
return err
return errors.Wrap(ErrBackendHandshake, err)
}

func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
if serverPkt, err = backendIO.ReadPacket(); err != nil {
err = errors.Wrap(ErrBackendHandshake, err)
return
}
if pnet.IsErrorPacket(serverPkt[0]) {
err = pnet.ParseErrorPacket(serverPkt)
err = errors.Wrap(ErrBackendHandshake, pnet.ParseErrorPacket(serverPkt))
return
}
capability, _, _ = pnet.ParseInitialHandshake(serverPkt)
Expand Down Expand Up @@ -346,7 +346,7 @@ func (auth *Authenticator) writeAuthHandshake(
var enableTLS bool
if auth.requireBackendTLS {
if backendTLSConfig == nil {
return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg)
return ErrProxyNoTLS
}
enableTLS = true
} else {
Expand All @@ -358,7 +358,7 @@ func (auth *Authenticator) writeAuthHandshake(
pkt = pnet.MakeHandshakeResponse(resp)
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
return errors.Wrap(ErrBackendHandshake, err)
}
// Send TLS / SSL request packet. The server must have supported TLS.
tcfg := backendTLSConfig.Clone()
Expand All @@ -370,15 +370,18 @@ func (auth *Authenticator) writeAuthHandshake(
if err := backendIO.ClientTLSHandshake(tcfg); err != nil {
// tiproxy pp enabled, tidb pp disabled, tls enabled => tls handshake encounters unrecognized packet
// tiproxy pp disabled, tidb pp enabled, tls enabled => tls handshake encounters unrecognized packet
return pnet.WrapUserError(err, checkPPV2ErrMsg)
return errors.Wrap(ErrBackendPPV2, err)
}
} else {
resp.Capability &= ^pnet.ClientSSL
pkt = pnet.MakeHandshakeResponse(resp)
}

// write handshake resp
return backendIO.WritePacket(pkt, true)
if err := backendIO.WritePacket(pkt, true); err != nil {
return errors.Wrap(ErrBackendHandshake, err)
}
return nil
}

func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error {
Expand All @@ -393,7 +396,7 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(data)
default: // mysql.AuthSwitchRequest, ShaCommand:
return errors.Errorf("read unexpected command: %#x", data[0])
return errors.Wrapf(mysql.ErrMalformPacket, "read unexpected command: %#x", data[0])
}
}

Expand Down
38 changes: 22 additions & 16 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -70,10 +69,14 @@ func TestUnsupportedCapability(t *testing.T) {
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) {
if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
} else if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps {
require.ErrorIs(t, ts.mp.err, ErrClientCap)
require.Nil(t, ErrToClient(ts.mp.err))
require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err))
} else if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps {
require.ErrorIs(t, ts.mp.err, ErrBackendCap)
require.Equal(t, ErrBackendCap, ErrToClient(ts.mp.err))
require.Equal(t, SrcBackendHandshake, Error2Source(ts.mp.err))
} else {
require.NoError(t, ts.mc.err)
require.NoError(t, ts.mp.err)
Expand Down Expand Up @@ -311,31 +314,35 @@ func TestAuthFail(t *testing.T) {
ts, clean := newTestSuite(t, tc, cfg)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
require.Equal(t, len(ts.mc.authData), len(ts.mb.authData))
require.Equal(t, SrcClientAuthFail, Error2Source(ts.mp.err))
})
clean()
}
}

func TestRequireBackendTLS(t *testing.T) {
tests := []struct {
cfg cfgOverrider
errMsg string
cfg cfgOverrider
err error
src ErrorSource
}{
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.proxyConfig.backendTLSConfig = nil
cfg.backendConfig.capability |= pnet.ClientSSL
},
errMsg: requireProxyTLSErrMsg,
err: ErrProxyNoTLS,
src: SrcProxyErr,
},
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.backendConfig.tlsConfig = nil
cfg.backendConfig.capability &= ^pnet.ClientSSL
},
errMsg: requireTiDBTLSErrMsg,
err: ErrBackendNoTLS,
src: SrcBackendHandshake,
},
{
cfg: func(cfg *testConfig) {
Expand All @@ -351,10 +358,9 @@ func TestRequireBackendTLS(t *testing.T) {
for _, tt := range tests {
ts, clean := newTestSuite(t, tc, tt.cfg)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
if len(tt.errMsg) > 0 {
var userError *pnet.UserError
require.True(t, errors.As(ts.mp.err, &userError))
require.Equal(t, tt.errMsg, userError.UserMsg())
if tt.err != nil {
require.ErrorIs(t, ts.mp.err, tt.err)
require.Equal(t, tt.src, Error2Source(ts.mp.err))
} else {
require.NoError(t, ts.mp.err)
}
Expand Down Expand Up @@ -401,9 +407,9 @@ func TestProxyProtocol(t *testing.T) {
// TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable.
// So when backend enables proxy-protocol and proxy disables it, it still works well.
if ts.mp.bcConfig.ProxyProtocol && !ts.mb.proxyProtocol {
var userError *pnet.UserError
require.True(t, errors.As(ts.mp.err, &userError))
require.Equal(t, checkPPV2ErrMsg, userError.UserMsg())
err := ErrToClient(ts.mp.err)
require.Equal(t, ErrBackendPPV2, err)
require.Equal(t, SrcBackendHandshake, Error2Source(err))
} else {
require.NoError(t, ts.mp.err)
}
Expand Down
Loading

0 comments on commit f22f82b

Please sign in to comment.