Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend, net: add more error sources #407

Merged
merged 5 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions pkg/manager/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@ import (
"time"

glist "github.com/bahlo/generic-list-go"
"github.com/pingcap/tiproxy/lib/util/errors"
)

var (
ErrNoInstanceToSelect = errors.New("no instances to route")
)

// ConnEventReceiver receives connection events.
Expand Down
73 changes: 38 additions & 35 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ import (
"go.uber.org/zap"
)

var (
ErrCapabilityNegotiation = errors.New("capability negotiation failed")
)

const unknownAuthPlugin = "auth_unknown_plugin"
const requiredFrontendCaps = pnet.ClientProtocol41
const defRequiredBackendCaps = pnet.ClientDeprecateEOF
Expand Down Expand Up @@ -76,10 +72,10 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)
return errors.Wrapf(ErrBackendCap, "require %s from backend", requiredBackendCaps^commonCaps)
}
if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) {
return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg)
return ErrBackendNoTLS
}
return nil
}
Expand All @@ -106,7 +102,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
frontendCapability := pnet.Capability(binary.LittleEndian.Uint32(pkt))
if isSSL {
if _, err = clientIO.ServerTLSHandshake(frontendTLSConfig); err != nil {
return pnet.WrapUserError(err, err.Error())
return errors.Wrap(ErrClientHandshake, err)
}
pkt, _, err = clientIO.ReadSSLRequestOrHandshakeResp()
if err != nil {
Expand All @@ -125,7 +121,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil {
return writeErr
}
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)
return errors.Wrapf(ErrClientCap, "require %s from frontend", requiredFrontendCaps&^commonCaps)
}
commonCaps := frontendCapability & proxyCapability
if frontendCapability^commonCaps != 0 {
Expand All @@ -147,10 +143,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
if errors.As(err, &warning) {
logger.Warn("parse handshake response encounters error", zap.Error(err))
} else if err != nil {
return pnet.WrapUserError(err, parsePktErrMsg)
return err
}
if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil {
return pnet.WrapUserError(err, err.Error())
return errors.Wrap(ErrProxyErr, err)
}
auth.user = clientResp.User
auth.dbname = clientResp.DB
Expand All @@ -163,29 +159,28 @@ RECONNECT:
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(cctx, auth, clientResp)
if err != nil {
return pnet.WrapUserError(err, connectErrMsg)
return err
}
backendIO.ResetSequence()

// write proxy header
if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil {
return pnet.WrapUserError(err, handshakeErrMsg)
return err
}

// read backend initial handshake
serverPkt, backendCapability, err := auth.readInitialHandshake(backendIO)
if err != nil {
if IsMySQLError(err) {
if pnet.IsMySQLError(err) {
if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil {
err = writeErr
return writeErr
}
return err
}
return pnet.WrapUserError(err, handshakeErrMsg)
return err
}

if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return pnet.WrapUserError(err, capabilityErrMsg)
return err
}

if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
Expand All @@ -207,7 +202,7 @@ RECONNECT:
// Copy the auth data so that the backend can set correct `using password` in the error message.
unknownAuthPlugin, clientResp.AuthData, 0,
); err != nil {
return pnet.WrapUserError(err, handshakeErrMsg)
return err
}

// forward other packets
Expand All @@ -220,16 +215,18 @@ loop:
// tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence
// tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence
if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) {
return pnet.WrapUserError(err, checkPPV2ErrMsg)
return errors.Wrap(ErrBackendPPV2, err)
}
return err
}
var packetErr error
var packetErr *mysql.MyError
if serverPkt[0] == pnet.ErrHeader.Byte() {
packetErr = pnet.ParseErrorPacket(serverPkt)
if handshakeHandler.HandleHandshakeErr(cctx, packetErr.(*mysql.MyError)) {
logger.Warn("handle handshake error, start reconnect", zap.Error(err))
backendIO.Close()
if handshakeHandler.HandleHandshakeErr(cctx, packetErr) {
logger.Warn("handle handshake error, start reconnect", zap.Error(packetErr))
if closeErr := backendIO.Close(); closeErr != nil {
logger.Warn("close backend error", zap.Error(closeErr))
}
goto RECONNECT
}
}
Expand All @@ -238,17 +235,17 @@ loop:
return err
}
if packetErr != nil {
return packetErr
return errors.Wrap(ErrClientAuthFail, packetErr)
}

pktIdx++
switch serverPkt[0] {
case pnet.OKHeader.Byte():
if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil {
return err
return errors.Wrap(ErrClientHandshake, err)
}
if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
return err
return errors.Wrap(ErrBackendHandshake, err)
}
return nil
default: // mysql.AuthSwitchRequest, ShaCommand
Expand Down Expand Up @@ -276,7 +273,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) {

func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error {
if len(sessionToken) == 0 {
return errors.New("session token is empty")
return errors.Wrapf(ErrBackendHandshake, "session token is empty")
}

// write proxy header
Expand All @@ -301,17 +298,20 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
}

if err = auth.handleSecondAuthResult(backendIO); err == nil {
return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel)
if err = setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
return errors.Wrap(ErrBackendHandshake, err)
}
}
return err
return errors.Wrap(ErrBackendHandshake, err)
}

func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
if serverPkt, err = backendIO.ReadPacket(); err != nil {
err = errors.Wrap(ErrBackendHandshake, err)
return
}
if pnet.IsErrorPacket(serverPkt[0]) {
err = pnet.ParseErrorPacket(serverPkt)
err = errors.Wrap(ErrBackendHandshake, pnet.ParseErrorPacket(serverPkt))
return
}
capability, _, _ = pnet.ParseInitialHandshake(serverPkt)
Expand Down Expand Up @@ -346,7 +346,7 @@ func (auth *Authenticator) writeAuthHandshake(
var enableTLS bool
if auth.requireBackendTLS {
if backendTLSConfig == nil {
return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg)
return ErrProxyNoTLS
}
enableTLS = true
} else {
Expand All @@ -358,7 +358,7 @@ func (auth *Authenticator) writeAuthHandshake(
pkt = pnet.MakeHandshakeResponse(resp)
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
return errors.Wrap(ErrBackendHandshake, err)
}
// Send TLS / SSL request packet. The server must have supported TLS.
tcfg := backendTLSConfig.Clone()
Expand All @@ -370,15 +370,18 @@ func (auth *Authenticator) writeAuthHandshake(
if err := backendIO.ClientTLSHandshake(tcfg); err != nil {
// tiproxy pp enabled, tidb pp disabled, tls enabled => tls handshake encounters unrecognized packet
// tiproxy pp disabled, tidb pp enabled, tls enabled => tls handshake encounters unrecognized packet
return pnet.WrapUserError(err, checkPPV2ErrMsg)
return errors.Wrap(ErrBackendPPV2, err)
}
} else {
resp.Capability &= ^pnet.ClientSSL
pkt = pnet.MakeHandshakeResponse(resp)
}

// write handshake resp
return backendIO.WritePacket(pkt, true)
if err := backendIO.WritePacket(pkt, true); err != nil {
return errors.Wrap(ErrBackendHandshake, err)
}
return nil
}

func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error {
Expand All @@ -393,7 +396,7 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(data)
default: // mysql.AuthSwitchRequest, ShaCommand:
return errors.Errorf("read unexpected command: %#x", data[0])
return errors.Wrapf(mysql.ErrMalformPacket, "read unexpected command: %#x", data[0])
}
}

Expand Down
38 changes: 22 additions & 16 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -70,10 +69,14 @@ func TestUnsupportedCapability(t *testing.T) {
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) {
if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
} else if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps {
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps {
require.ErrorIs(t, ts.mp.err, ErrClientCap)
require.Nil(t, ErrToClient(ts.mp.err))
require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err))
} else if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps {
require.ErrorIs(t, ts.mp.err, ErrBackendCap)
require.Equal(t, ErrBackendCap, ErrToClient(ts.mp.err))
require.Equal(t, SrcBackendHandshake, Error2Source(ts.mp.err))
} else {
require.NoError(t, ts.mc.err)
require.NoError(t, ts.mp.err)
Expand Down Expand Up @@ -311,31 +314,35 @@ func TestAuthFail(t *testing.T) {
ts, clean := newTestSuite(t, tc, cfg)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
require.Equal(t, len(ts.mc.authData), len(ts.mb.authData))
require.Equal(t, SrcClientAuthFail, Error2Source(ts.mp.err))
})
clean()
}
}

func TestRequireBackendTLS(t *testing.T) {
tests := []struct {
cfg cfgOverrider
errMsg string
cfg cfgOverrider
err error
src ErrorSource
}{
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.proxyConfig.backendTLSConfig = nil
cfg.backendConfig.capability |= pnet.ClientSSL
},
errMsg: requireProxyTLSErrMsg,
err: ErrProxyNoTLS,
src: SrcProxyErr,
},
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.backendConfig.tlsConfig = nil
cfg.backendConfig.capability &= ^pnet.ClientSSL
},
errMsg: requireTiDBTLSErrMsg,
err: ErrBackendNoTLS,
src: SrcBackendHandshake,
},
{
cfg: func(cfg *testConfig) {
Expand All @@ -351,10 +358,9 @@ func TestRequireBackendTLS(t *testing.T) {
for _, tt := range tests {
ts, clean := newTestSuite(t, tc, tt.cfg)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
if len(tt.errMsg) > 0 {
var userError *pnet.UserError
require.True(t, errors.As(ts.mp.err, &userError))
require.Equal(t, tt.errMsg, userError.UserMsg())
if tt.err != nil {
require.ErrorIs(t, ts.mp.err, tt.err)
require.Equal(t, tt.src, Error2Source(ts.mp.err))
} else {
require.NoError(t, ts.mp.err)
}
Expand Down Expand Up @@ -401,9 +407,9 @@ func TestProxyProtocol(t *testing.T) {
// TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable.
// So when backend enables proxy-protocol and proxy disables it, it still works well.
if ts.mp.bcConfig.ProxyProtocol && !ts.mb.proxyProtocol {
var userError *pnet.UserError
require.True(t, errors.As(ts.mp.err, &userError))
require.Equal(t, checkPPV2ErrMsg, userError.UserMsg())
err := ErrToClient(ts.mp.err)
require.Equal(t, ErrBackendPPV2, err)
require.Equal(t, SrcBackendHandshake, Error2Source(err))
} else {
require.NoError(t, ts.mp.err)
}
Expand Down
Loading
Loading