From 2cc8f0fa0740fd64ca1c71d976a8831aa41ec531 Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Tue, 21 Nov 2023 12:38:07 +0800 Subject: [PATCH 1/4] add more error sources --- pkg/proxy/backend/authenticator.go | 18 ++-- pkg/proxy/backend/backend_conn_mgr.go | 46 ++++------ pkg/proxy/backend/error.go | 115 ++++++++++++++++++++++++- pkg/proxy/backend/handshake_handler.go | 50 ++--------- pkg/proxy/client/client_conn.go | 6 +- pkg/proxy/net/mysql.go | 4 +- pkg/proxy/net/packetio.go | 12 +-- pkg/proxy/net/packetio_test.go | 4 +- pkg/proxy/net/tls.go | 2 +- 9 files changed, 159 insertions(+), 98 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index f5acb6e2..6ae1ab3b 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -76,10 +76,10 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili // The error cannot be sent to the client because the client only expects an initial handshake packet. // The only way is to log it and disconnect. logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps)) - return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps) + return pnet.WrapUserError(errors.Wrap(ErrBackendHandshake, errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)), capabilityErrMsg) } if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) { - return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg) + return pnet.WrapUserError(errors.Wrap(ErrBackendHandshake, errors.New("backend doesn't enable TLS")), requireTiDBTLSErrMsg) } return nil } @@ -125,7 +125,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil { return writeErr } - return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps) + return errors.Wrap(ErrClientHandshake, errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)) } commonCaps := frontendCapability & proxyCapability if frontendCapability^commonCaps != 0 { @@ -185,7 +185,7 @@ RECONNECT: } if err := auth.verifyBackendCaps(logger, backendCapability); err != nil { - return pnet.WrapUserError(err, capabilityErrMsg) + return err } if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 { @@ -224,12 +224,14 @@ loop: } return err } - var packetErr error + var packetErr *mysql.MyError if serverPkt[0] == pnet.ErrHeader.Byte() { packetErr = pnet.ParseErrorPacket(serverPkt) - if handshakeHandler.HandleHandshakeErr(cctx, packetErr.(*mysql.MyError)) { - logger.Warn("handle handshake error, start reconnect", zap.Error(err)) - backendIO.Close() + if handshakeHandler.HandleHandshakeErr(cctx, packetErr) { + logger.Warn("handle handshake error, start reconnect", zap.Error(packetErr)) + if closeErr := backendIO.Close(); closeErr != nil { + logger.Warn("close backend error", zap.Error(closeErr)) + } goto RECONNECT } } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 5123734d..85862c65 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -148,7 +148,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler // There are 2 types of signals, which may be sent concurrently. signalReceived: make(chan signalType, signalTypeNums), redirectResCh: make(chan *redirectResult, 1), - quitSource: SrcClientQuit, + quitSource: SrcNone, } mgr.SetValue(ConnContextKeyConnID, connectionID) return mgr @@ -170,13 +170,11 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.clientIO = clientIO err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) if err != nil { - mgr.setQuitSourceByErr(err) - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err, Error2Source(err)) clientIO.WriteUserError(err) return err } - mgr.resetQuitSource() - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil, SrcNone) mgr.cmdProcessor.capability = mgr.authenticator.capability childCtx, cancelFunc := context.WithCancel(ctx) @@ -235,8 +233,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { origErr = err - mgr.setQuitSourceByErr(err) - mgr.handshakeHandler.OnHandshake(cctx, addr, err) + mgr.handshakeHandler.OnHandshake(cctx, addr, err, Error2Source(err)) }, ) cancel() @@ -431,7 +428,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { // If the backend connection is closed, also close the client connection. // Otherwise, if the client is idle, the mgr will keep retrying. if errors.Is(rs.err, net.ErrClosed) || pnet.IsDisconnectError(rs.err) || errors.Is(rs.err, os.ErrDeadlineExceeded) { - mgr.quitSource = SrcBackendQuit + mgr.quitSource = SrcBackendNetwork if ignoredErr := mgr.clientIO.GracefulClose(); ignoredErr != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Error(ignoredErr)) } @@ -442,12 +439,10 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { return } - defer mgr.resetQuitSource() var cn net.Conn cn, rs.err = net.DialTimeout("tcp", rs.to, DialTimeout) if rs.err != nil { - mgr.quitSource = SrcBackendQuit - mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err) + mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err, SrcBackendNetwork) return } newBackendIO := pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn)) @@ -455,8 +450,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { rs.err = mgr.initSessionStates(newBackendIO, sessionStates) } else { - mgr.setQuitSourceByErr(rs.err) - mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err) + mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err, Error2Source(rs.err)) } if rs.err != nil { if ignoredErr := newBackendIO.Close(); ignoredErr != nil && !pnet.IsDisconnectError(ignoredErr) { @@ -469,7 +463,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } mgr.backendIO.Store(newBackendIO) mgr.setKeepAlive(mgr.config.HealthyKeepAlive) - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil, SrcNone) } // The original db in the auth info may be dropped during the session, so we need to authenticate with the current db. @@ -553,7 +547,7 @@ func (mgr *BackendConnManager) checkBackendActive() { if !backendIO.IsPeerActive() { mgr.logger.Info("backend connection is closed, close client connection", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Stringer("backend_addr", backendIO.RemoteAddr())) - mgr.quitSource = SrcBackendQuit + mgr.quitSource = SrcBackendNetwork if err := mgr.clientIO.GracefulClose(); err != nil { mgr.logger.Warn("graceful close client IO error", zap.Stringer("client_addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) } @@ -627,7 +621,7 @@ func (mgr *BackendConnManager) Close() error { } mgr.wg.Wait() - handErr := mgr.handshakeHandler.OnConnClose(mgr) + handErr := mgr.handshakeHandler.OnConnClose(mgr, mgr.quitSource) var connErr error var addr string @@ -677,26 +671,16 @@ func (mgr *BackendConnManager) setKeepAlive(cfg config.KeepAlive) { } } -// quitSource will be read by OnHandshake and OnConnClose, so setQuitSourceByErr should be called before them. func (mgr *BackendConnManager) setQuitSourceByErr(err error) { - // Do not update the source if err is nil. It may be already be set. if err == nil { return } - if errors.Is(err, ErrBackendConn) { - mgr.quitSource = SrcBackendQuit - } else if IsMySQLError(err) { - mgr.quitSource = SrcClientErr - } else if !errors.Is(err, ErrClientConn) { - mgr.quitSource = SrcProxyErr + // The source may be already be set. + // E.g. quitSource is set before TiProxy shuts down and client connection error is caused by shutdown instead of network. + if mgr.quitSource != SrcNone { + return } -} - -func (mgr *BackendConnManager) resetQuitSource() { - // SrcClientQuit is by default. - // Sometimes ErrClientConn is caused by GracefulClose and the quitSource is already set. - // Error maybe set during handshake for OnHandshake. If handshake finally succeeds, we reset it. - mgr.quitSource = SrcClientQuit + mgr.quitSource = Error2Source(err) } func (mgr *BackendConnManager) UpdateLogger(fields ...zap.Field) { diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 4097126a..38e44fa6 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -4,7 +4,9 @@ package backend import ( + gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tiproxy/lib/util/errors" + pnet "github.com/pingcap/tiproxy/pkg/proxy/net" ) const ( @@ -18,6 +20,115 @@ const ( ) var ( - ErrClientConn = errors.New("this is an error from client") - ErrBackendConn = errors.New("this is an error from backend") + ErrClientConn = errors.New("read or write client connection fail") + ErrClientHandshake = errors.New("handshake with client fail") + ErrBackendConn = errors.New("read or write backend connection fail") + ErrBackendHandshake = errors.New("handshake with backend fail") ) + +type SourceComp int + +const ( + CompNone SourceComp = iota + CompClient + CompProxy + CompBackend +) + +type ErrorSource int + +const ( + // SrcNone includes: succeed for OnHandshake; client normally quit for OnConnClose + SrcNone ErrorSource = iota + // SrcClientNetwork includes: EOF; reset by peer; connection refused; TLS handshake fails + SrcClientNetwork + // SrcClientHandshake includes: client capability unsupported + SrcClientHandshake + // SrcClientSQLErr includes: backend returns auth fail; SQL error + SrcClientSQLErr + // SrcProxyQuit includes: proxy graceful shutdown + SrcProxyQuit + // SrcProxyMalformed includes: malformed packet; invalid sequence + SrcProxyMalformed + // SrcProxyGetBackend includes: no backends + SrcProxyGetBackend + // SrcProxyErr includes: HandshakeHandler returns error; proxy disables TLS + SrcProxyErr + // SrcBackendNetwork includes: EOF; reset by peer; connection refused; TLS handshake fails + SrcBackendNetwork + // SrcBackendHandshake includes: backend capability unsupported; backend disables TLS + SrcBackendHandshake +) + +// Error2Source returns the ErrorSource by the error. +func Error2Source(err error) ErrorSource { + switch { + case err == nil: + return SrcNone + case errors.Is(err, pnet.ErrInvalidSequence) || errors.Is(err, gomysql.ErrMalformPacket): + // We assume the clients and TiDB are right and treat it as TiProxy bugs. + // ErrInvalidSequence may be wrapped with ErrClientConn or ErrBackendConn, so put it before other conditions. + return SrcProxyErr + case errors.Is(err, ErrClientConn): + return SrcClientNetwork + case errors.Is(err, ErrBackendConn): + return SrcBackendNetwork + case errors.Is(err, ErrClientHandshake): + return SrcClientHandshake + case errors.Is(err, ErrBackendHandshake): + return SrcBackendHandshake + case IsMySQLError(err): + return SrcClientSQLErr + default: + return SrcProxyErr + } +} + +func (es ErrorSource) String() string { + switch es { + case SrcNone: + return "ok" + case SrcClientNetwork: + return "client network break" + case SrcClientHandshake: + return "client handshake fail" + case SrcClientSQLErr: + return "client SQL error" + case SrcProxyQuit: + return "proxy shutdown" + case SrcProxyMalformed: + return "malformed packet" + case SrcProxyGetBackend: + return "proxy get backend fail" + case SrcProxyErr: + return "proxy error" + case SrcBackendNetwork: + return "backend network break" + case SrcBackendHandshake: + return "backend handshake fail" + } + return "unknown" +} + +// GetSourceComp returns which component does this error belong to. +func (es ErrorSource) GetSourceComp() SourceComp { + switch es { + case SrcClientNetwork, SrcClientHandshake, SrcClientSQLErr: + return CompClient + case SrcProxyQuit, SrcProxyMalformed, SrcProxyGetBackend, SrcProxyErr: + return CompProxy + case SrcBackendNetwork, SrcBackendHandshake: + return CompBackend + default: + return CompNone + } +} + +// Normal returns whether this error source is expected. +func (es ErrorSource) Normal() bool { + switch es { + case SrcNone, SrcProxyQuit: + return true + } + return false +} diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index cb2665ac..9cc8e4e0 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -12,6 +12,8 @@ import ( "go.uber.org/zap" ) +// Interfaces in this file are used for the serverless tier. + // Context keys. type ConnContextKey string @@ -21,41 +23,6 @@ const ( ConnContextKeyConnAddr ConnContextKey = "conn-addr" ) -type ErrorSource int - -const ( - // SrcClientQuit includes: client quit; bad client conn - SrcClientQuit ErrorSource = iota - // SrcClientErr includes: wrong password; mal format packet - SrcClientErr - // SrcProxyQuit includes: proxy graceful shutdown - SrcProxyQuit - // SrcProxyErr includes: cannot get backend list; capability negotiation - SrcProxyErr - // SrcBackendQuit includes: backend quit - SrcBackendQuit - // SrcBackendErr is reserved - SrcBackendErr -) - -func (es ErrorSource) String() string { - switch es { - case SrcClientQuit: - return "client quit" - case SrcClientErr: - return "client error" - case SrcProxyQuit: - return "proxy shutdown" - case SrcProxyErr: - return "proxy error" - case SrcBackendQuit: - return "backend quit" - case SrcBackendErr: - return "backend error" - } - return "unknown" -} - var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) var _ HandshakeHandler = (*CustomHandshakeHandler)(nil) @@ -64,7 +31,6 @@ type ConnContext interface { ServerAddr() string ClientInBytes() uint64 ClientOutBytes() uint64 - QuitSource() ErrorSource UpdateLogger(fields ...zap.Field) SetValue(key, val any) Value(key any) any @@ -74,8 +40,8 @@ type HandshakeHandler interface { HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool // return true means retry connect GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) - OnHandshake(ctx ConnContext, to string, err error) - OnConnClose(ctx ConnContext) error + OnHandshake(ctx ConnContext, to string, err error, src ErrorSource) + OnConnClose(ctx ConnContext, src ErrorSource) error OnTraffic(ctx ConnContext) GetCapability() pnet.Capability GetServerVersion() string @@ -111,13 +77,13 @@ func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Ha return ns.GetRouter(), nil } -func (handler *DefaultHandshakeHandler) OnHandshake(ConnContext, string, error) { +func (handler *DefaultHandshakeHandler) OnHandshake(ConnContext, string, error, ErrorSource) { } func (handler *DefaultHandshakeHandler) OnTraffic(ConnContext) { } -func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext) error { +func (handler *DefaultHandshakeHandler) OnConnClose(ConnContext, ErrorSource) error { return nil } @@ -156,7 +122,7 @@ func (h *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Handshake return nil, errors.New("no router") } -func (h *CustomHandshakeHandler) OnHandshake(ctx ConnContext, addr string, err error) { +func (h *CustomHandshakeHandler) OnHandshake(ctx ConnContext, addr string, err error, src ErrorSource) { if h.onHandshake != nil { h.onHandshake(ctx, addr, err) } @@ -168,7 +134,7 @@ func (h *CustomHandshakeHandler) OnTraffic(ctx ConnContext) { } } -func (h *CustomHandshakeHandler) OnConnClose(ctx ConnContext) error { +func (h *CustomHandshakeHandler) OnConnClose(ctx ConnContext, src ErrorSource) error { if h.onConnClose != nil { return h.onConnClose(ctx) } diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index a2fc4f50..b66afe16 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -57,11 +57,9 @@ func (cc *ClientConnection) Run(ctx context.Context) { clean: src := cc.connMgr.QuitSource() - switch src { - case backend.SrcClientQuit, backend.SrcClientErr, backend.SrcProxyQuit: - default: + if !src.Normal() { fields := cc.connMgr.ConnInfo() - fields = append(fields, zap.Stringer("quit source", src), zap.Error(err)) + fields = append(fields, zap.Stringer("quit_source", src), zap.Error(err)) cc.logger.Warn(msg, fields...) } } diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index d9b56f56..74a98ccd 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -32,7 +32,7 @@ func ParseInitialHandshake(data []byte) (Capability, uint64, string) { // skip min version serverVersion := string(data[1 : 1+bytes.IndexByte(data[1:], 0)]) pos := 1 + len(serverVersion) + 1 - connid := uint32(binary.LittleEndian.Uint32(data[pos : pos+4])) + connid := binary.LittleEndian.Uint32(data[pos : pos+4]) // skip salt first part // skip filter pos += 4 + 8 + 1 @@ -398,7 +398,7 @@ func ParseOKPacket(data []byte) uint16 { } // ParseErrorPacket transforms an error packet into a MyError object. -func ParseErrorPacket(data []byte) error { +func ParseErrorPacket(data []byte) *gomysql.MyError { e := new(gomysql.MyError) pos := 1 e.Code = binary.LittleEndian.Uint16(data[pos:]) diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 8b5013b1..201a6b33 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -333,7 +333,7 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first for { header, err := p.readWriter.Peek(5) if err != nil { - return errors.Wrap(ErrReadConn, err) + return p.wrapErr(errors.Wrap(ErrReadConn, err)) } length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16 end, needData := isEnd(header[4], length) @@ -343,30 +343,30 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first // TODO: allocate a buffer from pool and return the buffer after `process`. data, err = p.ReadPacket() if err != nil { - return errors.Wrap(ErrReadConn, err) + return p.wrapErr(errors.Wrap(ErrReadConn, err)) } if err := dest.WritePacket(data, false); err != nil { - return errors.Wrap(ErrWriteConn, err) + return p.wrapErr(errors.Wrap(ErrWriteConn, err)) } } else { for { sequence, pktSequence := header[3], p.readWriter.Sequence() if sequence != pktSequence { - return ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence) + return p.wrapErr(ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)) } p.readWriter.SetSequence(sequence + 1) // Sequence may be different (e.g. with compression) so we can't just copy the data to the destination. dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1) p.limitReader.N = int64(length + 4) if _, err := dest.readWriter.ReadFrom(&p.limitReader); err != nil { - return errors.Wrap(ErrRelayConn, err) + return p.wrapErr(errors.Wrap(ErrRelayConn, err)) } // For large packets, continue. if length < MaxPayloadLen { break } if header, err = p.readWriter.Peek(4); err != nil { - return errors.Wrap(ErrReadConn, err) + return p.wrapErr(errors.Wrap(ErrReadConn, err)) } length = int(header[0]) | int(header[1])<<8 | int(header[2])<<16 } diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 100ae4d1..4c092437 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -265,13 +265,13 @@ func TestPredefinedPacket(t *testing.T) { func(t *testing.T, cli *PacketIO) { data, err := cli.ReadPacket() require.NoError(t, err) - merr := ParseErrorPacket(data).(*mysql.MyError) + merr := ParseErrorPacket(data) require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code) require.Equal(t, "Unknown error", merr.Message) data, err = cli.ReadPacket() require.NoError(t, err) - merr = ParseErrorPacket(data).(*mysql.MyError) + merr = ParseErrorPacket(data) require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code) require.Equal(t, "test error", merr.Message) diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go index 4ad416ec..9303ecf3 100644 --- a/pkg/proxy/net/tls.go +++ b/pkg/proxy/net/tls.go @@ -39,7 +39,7 @@ func (p *PacketIO) ClientTLSHandshake(tlsConfig *tls.Config) error { conn := &tlsInternalConn{p.readWriter} tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { - return errors.WithStack(errors.Wrap(ErrHandshakeTLS, err)) + return p.wrapErr(errors.Wrap(ErrHandshakeTLS, err)) } p.readWriter = newTLSReadWriter(p.readWriter, tlsConn) return nil From 377e5fad401c5d6097dacbac6cd1b51501817ba9 Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Thu, 23 Nov 2023 15:51:23 +0800 Subject: [PATCH 2/4] remove user error --- pkg/manager/router/router.go | 5 - pkg/proxy/backend/authenticator.go | 45 ++++----- pkg/proxy/backend/authenticator_test.go | 28 +++--- pkg/proxy/backend/backend_conn_mgr.go | 21 ++-- pkg/proxy/backend/backend_conn_mgr_test.go | 18 ++-- pkg/proxy/backend/cmd_processor.go | 10 -- pkg/proxy/backend/cmd_processor_exec.go | 2 +- pkg/proxy/backend/cmd_processor_test.go | 2 +- pkg/proxy/backend/error.go | 112 ++++++++++++++------- pkg/proxy/backend/handshake_handler.go | 8 +- pkg/proxy/backend/testsuite_test.go | 2 +- pkg/proxy/net/compress.go | 9 +- pkg/proxy/net/error.go | 32 ------ pkg/proxy/net/mysql.go | 6 ++ pkg/proxy/net/packetio_mysql.go | 8 +- 15 files changed, 152 insertions(+), 156 deletions(-) diff --git a/pkg/manager/router/router.go b/pkg/manager/router/router.go index e77a70b8..28c9d7a4 100644 --- a/pkg/manager/router/router.go +++ b/pkg/manager/router/router.go @@ -7,11 +7,6 @@ import ( "time" glist "github.com/bahlo/generic-list-go" - "github.com/pingcap/tiproxy/lib/util/errors" -) - -var ( - ErrNoInstanceToSelect = errors.New("no instances to route") ) // ConnEventReceiver receives connection events. diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 6ae1ab3b..7abaff72 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -18,10 +18,6 @@ import ( "go.uber.org/zap" ) -var ( - ErrCapabilityNegotiation = errors.New("capability negotiation failed") -) - const unknownAuthPlugin = "auth_unknown_plugin" const requiredFrontendCaps = pnet.ClientProtocol41 const defRequiredBackendCaps = pnet.ClientDeprecateEOF @@ -76,10 +72,10 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili // The error cannot be sent to the client because the client only expects an initial handshake packet. // The only way is to log it and disconnect. logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps)) - return pnet.WrapUserError(errors.Wrap(ErrBackendHandshake, errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)), capabilityErrMsg) + return errors.Wrapf(ErrBackendCap, "require %s from backend", requiredBackendCaps^commonCaps) } if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) { - return pnet.WrapUserError(errors.Wrap(ErrBackendHandshake, errors.New("backend doesn't enable TLS")), requireTiDBTLSErrMsg) + return ErrBackendNoTLS } return nil } @@ -106,7 +102,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte frontendCapability := pnet.Capability(binary.LittleEndian.Uint32(pkt)) if isSSL { if _, err = clientIO.ServerTLSHandshake(frontendTLSConfig); err != nil { - return pnet.WrapUserError(err, err.Error()) + return errors.Wrap(ErrClientHandshake, err) } pkt, _, err = clientIO.ReadSSLRequestOrHandshakeResp() if err != nil { @@ -125,7 +121,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil { return writeErr } - return errors.Wrap(ErrClientHandshake, errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)) + return errors.Wrapf(ErrClientCap, "require %s from frontend", requiredFrontendCaps&^commonCaps) } commonCaps := frontendCapability & proxyCapability if frontendCapability^commonCaps != 0 { @@ -147,10 +143,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte if errors.As(err, &warning) { logger.Warn("parse handshake response encounters error", zap.Error(err)) } else if err != nil { - return pnet.WrapUserError(err, parsePktErrMsg) + return errors.Wrap(ErrProxyMalformed, err) } if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil { - return pnet.WrapUserError(err, err.Error()) + return errors.Wrap(ErrProxyErr, err) } auth.user = clientResp.User auth.dbname = clientResp.DB @@ -163,25 +159,24 @@ RECONNECT: // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. backendIO, err := getBackendIO(cctx, auth, clientResp) if err != nil { - return pnet.WrapUserError(err, connectErrMsg) + return err } backendIO.ResetSequence() // write proxy header if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil { - return pnet.WrapUserError(err, handshakeErrMsg) + return errors.Wrap(ErrBackendHandshake, err) } // read backend initial handshake serverPkt, backendCapability, err := auth.readInitialHandshake(backendIO) if err != nil { - if IsMySQLError(err) { + if pnet.IsMySQLError(err) { if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil { - err = writeErr + return writeErr } - return err } - return pnet.WrapUserError(err, handshakeErrMsg) + return errors.Wrap(ErrBackendHandshake, err) } if err := auth.verifyBackendCaps(logger, backendCapability); err != nil { @@ -207,7 +202,7 @@ RECONNECT: // Copy the auth data so that the backend can set correct `using password` in the error message. unknownAuthPlugin, clientResp.AuthData, 0, ); err != nil { - return pnet.WrapUserError(err, handshakeErrMsg) + return err } // forward other packets @@ -220,7 +215,7 @@ loop: // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) { - return pnet.WrapUserError(err, checkPPV2ErrMsg) + return ErrBackendPPV2 } return err } @@ -240,7 +235,7 @@ loop: return err } if packetErr != nil { - return packetErr + return errors.Wrap(ErrClientAuthFail, packetErr) } pktIdx++ @@ -278,12 +273,12 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) { func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error { if len(sessionToken) == 0 { - return errors.New("session token is empty") + return errors.Wrapf(ErrProxyErr, "session token is empty") } // write proxy header if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil { - return err + return errors.Wrap(ErrBackendHandshake, err) } _, backendCapability, err := auth.readInitialHandshake(backendIO) @@ -305,7 +300,7 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac if err = auth.handleSecondAuthResult(backendIO); err == nil { return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel) } - return err + return errors.Wrap(ErrBackendHandshake, err) } func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) { @@ -348,7 +343,7 @@ func (auth *Authenticator) writeAuthHandshake( var enableTLS bool if auth.requireBackendTLS { if backendTLSConfig == nil { - return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg) + return ErrProxyNoTLS } enableTLS = true } else { @@ -372,7 +367,7 @@ func (auth *Authenticator) writeAuthHandshake( if err := backendIO.ClientTLSHandshake(tcfg); err != nil { // tiproxy pp enabled, tidb pp disabled, tls enabled => tls handshake encounters unrecognized packet // tiproxy pp disabled, tidb pp enabled, tls enabled => tls handshake encounters unrecognized packet - return pnet.WrapUserError(err, checkPPV2ErrMsg) + return errors.Wrap(ErrBackendPPV2, err) } } else { resp.Capability &= ^pnet.ClientSSL @@ -395,7 +390,7 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro case pnet.ErrHeader.Byte(): return pnet.ParseErrorPacket(data) default: // mysql.AuthSwitchRequest, ShaCommand: - return errors.Errorf("read unexpected command: %#x", data[0]) + return errors.Wrapf(mysql.ErrMalformPacket, "read unexpected command: %#x", data[0]) } } diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 47d3e99d..fcb0b69e 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tiproxy/lib/util/errors" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" "github.com/stretchr/testify/require" ) @@ -70,10 +69,10 @@ func TestUnsupportedCapability(t *testing.T) { for _, cfgs := range cfgOverriders { ts, clean := newTestSuite(t, tc, cfgs...) ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) { - if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps { - require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation) - } else if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps { - require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation) + if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps { + require.ErrorIs(t, ts.mp.err, ErrClientCap) + } else if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps { + require.ErrorIs(t, ts.mp.err, ErrBackendCap) } else { require.NoError(t, ts.mc.err) require.NoError(t, ts.mp.err) @@ -318,8 +317,8 @@ func TestAuthFail(t *testing.T) { func TestRequireBackendTLS(t *testing.T) { tests := []struct { - cfg cfgOverrider - errMsg string + cfg cfgOverrider + err error }{ { cfg: func(cfg *testConfig) { @@ -327,7 +326,7 @@ func TestRequireBackendTLS(t *testing.T) { cfg.proxyConfig.backendTLSConfig = nil cfg.backendConfig.capability |= pnet.ClientSSL }, - errMsg: requireProxyTLSErrMsg, + err: ErrProxyNoTLS, }, { cfg: func(cfg *testConfig) { @@ -335,7 +334,7 @@ func TestRequireBackendTLS(t *testing.T) { cfg.backendConfig.tlsConfig = nil cfg.backendConfig.capability &= ^pnet.ClientSSL }, - errMsg: requireTiDBTLSErrMsg, + err: ErrBackendNoTLS, }, { cfg: func(cfg *testConfig) { @@ -351,10 +350,8 @@ func TestRequireBackendTLS(t *testing.T) { for _, tt := range tests { ts, clean := newTestSuite(t, tc, tt.cfg) ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { - if len(tt.errMsg) > 0 { - var userError *pnet.UserError - require.True(t, errors.As(ts.mp.err, &userError)) - require.Equal(t, tt.errMsg, userError.UserMsg()) + if tt.err != nil { + require.ErrorIs(t, ts.mp.err, tt.err) } else { require.NoError(t, ts.mp.err) } @@ -401,9 +398,8 @@ func TestProxyProtocol(t *testing.T) { // TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable. // So when backend enables proxy-protocol and proxy disables it, it still works well. if ts.mp.bcConfig.ProxyProtocol && !ts.mb.proxyProtocol { - var userError *pnet.UserError - require.True(t, errors.As(ts.mp.err, &userError)) - require.Equal(t, checkPPV2ErrMsg, userError.UserMsg()) + err := ErrToClient(ts.mp.err) + require.Equal(t, ErrBackendPPV2, err) } else { require.NoError(t, ts.mp.err) } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 85862c65..4cf7f5bc 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -170,8 +170,13 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.clientIO = clientIO err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) if err != nil { - mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err, Error2Source(err)) - clientIO.WriteUserError(err) + src := Error2Source(err) + mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err, src) + // For some errors, convert them to MySQL errors and send them to the client. + if clientErr := ErrToClient(err); clientErr != nil { + clientIO.WriteUserError(clientErr) + } + mgr.quitSource = src return err } mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), nil, SrcNone) @@ -189,7 +194,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { r, err := mgr.handshakeHandler.GetRouter(cctx, resp) if err != nil { - return nil, pnet.WrapUserError(err, err.Error()) + return nil, errors.Wrap(ErrProxyErr, err) } // Reasons to wait: // - The TiDB instances may not be initialized yet @@ -209,17 +214,17 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato addr, err = selector.Next() } if err != nil { - return nil, backoff.Permanent(pnet.WrapUserError(err, err.Error())) + return nil, backoff.Permanent(errors.Wrap(ErrProxyErr, err)) } if addr == "" { - return nil, router.ErrNoInstanceToSelect + return nil, ErrProxyNoBackend } var cn net.Conn cn, err = net.DialTimeout("tcp", addr, DialTimeout) selector.Finish(mgr, err == nil) if err != nil { - return nil, errors.Wrapf(err, "dial backend %s error", addr) + return nil, errors.Wrap(ErrBackendHandshake, errors.Wrapf(err, "dial backend %s error", addr)) } // NOTE: should use DNS name as much as possible @@ -282,7 +287,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( addCmdMetrics(cmd, mgr.ServerAddr(), startTime) } if err != nil { - if !IsMySQLError(err) { + if !pnet.IsMySQLError(err) { return } else { mgr.logger.Debug("got a mysql error", zap.Error(err), zap.Stringer("cmd", cmd)) @@ -318,7 +323,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( // Execute the held request no matter redirection succeeds or not. _, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO.Load(), false) addCmdMetrics(cmd, mgr.ServerAddr(), startTime) - if err != nil && !IsMySQLError(err) { + if err != nil && !pnet.IsMySQLError(err) { return } } else if mgr.closeStatus.Load() == statusNotifyClose { diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index f07f3329..9d702881 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -150,7 +150,7 @@ func (ts *backendMgrTester) redirectSucceed4Proxy(_, _ *pnet.PacketIO) error { ts.mp.Redirect(ts.tc.backendListener.Addr().String()) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(ts.t, eventSucceed) require.NotEqual(ts.t, backend1, ts.mp.backendIO.Load()) - require.Equal(ts.t, SrcClientQuit, ts.mp.QuitSource()) + require.Equal(ts.t, SrcNone, ts.mp.QuitSource()) return nil } @@ -372,7 +372,7 @@ func TestRedirectInTxn(t *testing.T) { require.NoError(t, err) ts.mp.getEventReceiver().(*mockEventReceiver).checkEvent(t, eventFail) require.Equal(t, backend1, ts.mp.backendIO.Load()) - require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) + require.Equal(t, SrcNone, ts.mp.QuitSource()) return nil }, backend: func(packetIO *pnet.PacketIO) error { @@ -411,7 +411,7 @@ func TestConnectFail(t *testing.T) { }, { proxy: func(clientIO, backendIO *pnet.PacketIO) error { - require.Equal(t, SrcClientErr, ts.mp.QuitSource()) + require.Equal(t, SrcClientAuthFail, ts.mp.QuitSource()) return nil }, }, @@ -624,7 +624,7 @@ func TestCustomHandshake(t *testing.T) { }, { proxy: func(clientIO, backendIO *pnet.PacketIO) error { - require.Equal(t, SrcClientQuit, ts.mp.QuitSource()) + require.Equal(t, SrcNone, ts.mp.QuitSource()) return nil }, }, @@ -767,8 +767,8 @@ func TestHandlerReturnError(t *testing.T) { return router.NewStaticRouter(nil), nil } }, - errMsg: connectErrMsg, - quitSource: SrcProxyErr, + errMsg: ErrProxyNoBackend.Error(), + quitSource: SrcProxyNoBackend, }, } for _, test := range tests { @@ -847,12 +847,12 @@ func TestGetBackendIO(t *testing.T) { getRouter: func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { return rt, nil }, - onHandshake: func(connContext ConnContext, s string, err error) { + onHandshake: func(connContext ConnContext, s string, err error, src ErrorSource) { if err != nil && len(s) > 0 { badAddrs[s] = struct{}{} } if err != nil { - require.Equal(t, SrcProxyErr, connContext.QuitSource()) + require.Equal(t, SrcBackendHandshake, src) } }, } @@ -961,7 +961,7 @@ func TestBackendInactive(t *testing.T) { }, { proxy: func(clientIO, backendIO *pnet.PacketIO) error { - require.Equal(t, SrcBackendQuit, ts.mp.QuitSource()) + require.Equal(t, SrcBackendNetwork, ts.mp.QuitSource()) return nil }, }, diff --git a/pkg/proxy/backend/cmd_processor.go b/pkg/proxy/backend/cmd_processor.go index 74fde4b4..2b4eb923 100644 --- a/pkg/proxy/backend/cmd_processor.go +++ b/pkg/proxy/backend/cmd_processor.go @@ -6,7 +6,6 @@ package backend import ( "encoding/binary" - gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tidb/parser/mysql" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" "go.uber.org/zap" @@ -116,12 +115,3 @@ func (cp *CmdProcessor) hasPendingPreparedStmts() bool { } return false } - -// IsMySQLError returns true if the error is a MySQL error. -func IsMySQLError(err error) bool { - if err == nil { - return false - } - _, ok := err.(*gomysql.MyError) - return ok -} diff --git a/pkg/proxy/backend/cmd_processor_exec.go b/pkg/proxy/backend/cmd_processor_exec.go index 620bfd43..c82fb633 100644 --- a/pkg/proxy/backend/cmd_processor_exec.go +++ b/pkg/proxy/backend/cmd_processor_exec.go @@ -25,7 +25,7 @@ func (cp *CmdProcessor) executeCmd(request []byte, clientIO, backendIO *pnet.Pac var response []byte if _, response, err = cp.query(backendIO, "COMMIT"); err != nil { // If commit fails, forward the response to the client. - if IsMySQLError(err) { + if pnet.IsMySQLError(err) { if writeErr := clientIO.WritePacket(response, true); writeErr != nil { return false, writeErr } diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index d5050a72..26162351 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -1027,7 +1027,7 @@ func TestNetworkError(t *testing.T) { clientErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) require.True(t, pnet.IsDisconnectError(ts.mc.err)) - require.NotNil(t, ts.mp.err.(*pnet.UserError)) + require.NotNil(t, ErrToClient(ts.mp.err)) } backendErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 38e44fa6..e259f634 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -9,23 +9,55 @@ import ( pnet "github.com/pingcap/tiproxy/pkg/proxy/net" ) -const ( - connectErrMsg = "No available TiDB instances, please make sure TiDB is available" - parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP" - handshakeErrMsg = "TiProxy fails to connect to TiDB, please make sure TiDB is available" - capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB" - requireProxyTLSErrMsg = "Require TLS enabled on TiProxy when require-backend-tls=true" - requireTiDBTLSErrMsg = "Require TLS enabled on TiDB when require-backend-tls=true" - checkPPV2ErrMsg = "TiProxy fails to connect to TiDB, please make sure TiDB proxy-protocol is set correctly. If this error still exists, please contact PingCAP" +// These errors may not be disconnection errors. They are used for marking whether the error comes from the client or the backend. +var ( + ErrClientConn = errors.New("this is an error from the client connection") + ErrBackendConn = errors.New("this is an error from the backend connection") ) +// These errors are used to track internal errors. var ( - ErrClientConn = errors.New("read or write client connection fail") - ErrClientHandshake = errors.New("handshake with client fail") - ErrBackendConn = errors.New("read or write backend connection fail") - ErrBackendHandshake = errors.New("handshake with backend fail") + ErrClientCap = errors.New("Verify client capability failed, please upgrade the client") + ErrClientHandshake = errors.New("Fails to handshake with the client") + ErrClientAuthFail = errors.New("Authentication fails") + ErrProxyMalformed = errors.New("TiProxy fails to parse the packet, please contact PingCAP") + ErrProxyErr = errors.New("Other serverless error") + ErrProxyNoBackend = errors.New("No available TiDB instances, please make sure TiDB is available") + ErrProxyNoTLS = errors.New("Require TLS enabled on TiProxy when require-backend-tls=true") + ErrBackendCap = errors.New("Verify TiDB capability failed, please upgrade TiDB") + ErrBackendHandshake = errors.New("TiProxy fails to connect to TiDB, please make sure TiDB is available") + ErrBackendNoTLS = errors.New("Require TLS enabled on TiDB when require-backend-tls=true") + ErrBackendPPV2 = errors.New("TiProxy fails to connect to TiDB, please make sure TiDB proxy-protocol is set correctly. If this error still exists, please contact PingCAP") ) +// ErrToClient returns the error that needs to be sent to the client. +func ErrToClient(err error) error { + switch { + case pnet.IsMySQLError(err): + // If it's a MySQL error, it should be already sent to the client. + return nil + case errors.Is(err, ErrProxyMalformed): + return ErrProxyMalformed + case errors.Is(err, ErrProxyNoBackend): + return ErrProxyNoBackend + case errors.Is(err, ErrProxyNoTLS): + return ErrProxyNoTLS + case errors.Is(err, ErrBackendCap): + return ErrBackendCap + case errors.Is(err, ErrBackendHandshake): + return ErrBackendHandshake + case errors.Is(err, ErrBackendNoTLS): + return ErrBackendNoTLS + case errors.Is(err, ErrBackendPPV2): + return ErrBackendPPV2 + case errors.Is(err, ErrProxyErr): + // The error is returned by HandshakeHandler/BackendFetcher and wrapped with ErrProxyErr. + return errors.Unwrap(err) + } + // For other errors, we don't send them to the client. + return nil +} + type SourceComp int const ( @@ -40,23 +72,25 @@ type ErrorSource int const ( // SrcNone includes: succeed for OnHandshake; client normally quit for OnConnClose SrcNone ErrorSource = iota - // SrcClientNetwork includes: EOF; reset by peer; connection refused; TLS handshake fails + // SrcClientNetwork includes: EOF; reset by peer; connection refused SrcClientNetwork - // SrcClientHandshake includes: client capability unsupported + // SrcClientHandshake includes: client capability unsupported; TLS handshake fails SrcClientHandshake - // SrcClientSQLErr includes: backend returns auth fail; SQL error + // SrcClientAuthFail includes: backend returns auth fail + SrcClientAuthFail + // SrcClientSQLErr includes: SQL error SrcClientSQLErr // SrcProxyQuit includes: proxy graceful shutdown SrcProxyQuit // SrcProxyMalformed includes: malformed packet; invalid sequence SrcProxyMalformed - // SrcProxyGetBackend includes: no backends - SrcProxyGetBackend - // SrcProxyErr includes: HandshakeHandler returns error; proxy disables TLS + // SrcProxyNoBackend includes: no backends + SrcProxyNoBackend + // SrcProxyErr includes: HandshakeHandler returns error; proxy disables TLS; unexpected errors SrcProxyErr - // SrcBackendNetwork includes: EOF; reset by peer; connection refused; TLS handshake fails + // SrcBackendNetwork includes: EOF; reset by peer; connection refused SrcBackendNetwork - // SrcBackendHandshake includes: backend capability unsupported; backend disables TLS + // SrcBackendHandshake includes: dial failure; backend capability unsupported; backend disables TLS; TLS handshake fails SrcBackendHandshake ) @@ -65,41 +99,51 @@ func Error2Source(err error) ErrorSource { switch { case err == nil: return SrcNone - case errors.Is(err, pnet.ErrInvalidSequence) || errors.Is(err, gomysql.ErrMalformPacket): - // We assume the clients and TiDB are right and treat it as TiProxy bugs. - // ErrInvalidSequence may be wrapped with ErrClientConn or ErrBackendConn, so put it before other conditions. - return SrcProxyErr - case errors.Is(err, ErrClientConn): + // Disconnection errors may come from other errors such as ErrProxyNoBackend. + // ErrClientConn and ErrBackendConn may include non-connection errors. + case pnet.IsDisconnectError(err) && errors.Is(err, ErrClientConn): return SrcClientNetwork - case errors.Is(err, ErrBackendConn): + case pnet.IsDisconnectError(err) && errors.Is(err, ErrBackendConn): return SrcBackendNetwork + case errors.Is(err, pnet.ErrInvalidSequence), errors.Is(err, gomysql.ErrMalformPacket): + // We assume the clients and TiDB are right and treat it as TiProxy bugs. + return SrcProxyMalformed case errors.Is(err, ErrClientHandshake): return SrcClientHandshake + case errors.Is(err, ErrClientAuthFail): + return SrcClientAuthFail case errors.Is(err, ErrBackendHandshake): return SrcBackendHandshake - case IsMySQLError(err): + case errors.Is(err, ErrProxyNoBackend): + return SrcProxyNoBackend + case pnet.IsMySQLError(err): + // ErrClientAuthFail and ErrBackendHandshake may also contain MySQL error. return SrcClientSQLErr default: + // All other untracked errors are proxy errors. return SrcProxyErr } } +// String is used for metrics labels and log. func (es ErrorSource) String() string { switch es { case SrcNone: - return "ok" + return "success" case SrcClientNetwork: return "client network break" case SrcClientHandshake: return "client handshake fail" + case SrcClientAuthFail: + return "auth fail" case SrcClientSQLErr: - return "client SQL error" + return "SQL error" case SrcProxyQuit: return "proxy shutdown" case SrcProxyMalformed: return "malformed packet" - case SrcProxyGetBackend: - return "proxy get backend fail" + case SrcProxyNoBackend: + return "get backend fail" case SrcProxyErr: return "proxy error" case SrcBackendNetwork: @@ -113,9 +157,9 @@ func (es ErrorSource) String() string { // GetSourceComp returns which component does this error belong to. func (es ErrorSource) GetSourceComp() SourceComp { switch es { - case SrcClientNetwork, SrcClientHandshake, SrcClientSQLErr: + case SrcClientNetwork, SrcClientHandshake, SrcClientAuthFail, SrcClientSQLErr: return CompClient - case SrcProxyQuit, SrcProxyMalformed, SrcProxyGetBackend, SrcProxyErr: + case SrcProxyQuit, SrcProxyMalformed, SrcProxyNoBackend, SrcProxyErr: return CompProxy case SrcBackendNetwork, SrcBackendHandshake: return CompBackend @@ -127,7 +171,7 @@ func (es ErrorSource) GetSourceComp() SourceComp { // Normal returns whether this error source is expected. func (es ErrorSource) Normal() bool { switch es { - case SrcNone, SrcProxyQuit: + case SrcNone, SrcProxyQuit, SrcClientSQLErr: return true } return false diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 9cc8e4e0..d400540d 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -106,9 +106,9 @@ func (handler *DefaultHandshakeHandler) GetServerVersion() string { type CustomHandshakeHandler struct { getRouter func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) - onHandshake func(ConnContext, string, error) + onHandshake func(ConnContext, string, error, ErrorSource) onTraffic func(ConnContext) - onConnClose func(ConnContext) error + onConnClose func(ConnContext, ErrorSource) error handleHandshakeResp func(ctx ConnContext, resp *pnet.HandshakeResp) error handleHandshakeErr func(ctx ConnContext, err *gomysql.MyError) bool getCapability func() pnet.Capability @@ -124,7 +124,7 @@ func (h *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.Handshake func (h *CustomHandshakeHandler) OnHandshake(ctx ConnContext, addr string, err error, src ErrorSource) { if h.onHandshake != nil { - h.onHandshake(ctx, addr, err) + h.onHandshake(ctx, addr, err, src) } } @@ -136,7 +136,7 @@ func (h *CustomHandshakeHandler) OnTraffic(ctx ConnContext) { func (h *CustomHandshakeHandler) OnConnClose(ctx ConnContext, src ErrorSource) error { if h.onConnClose != nil { - return h.onConnClose(ctx) + return h.onConnClose(ctx, src) } return nil } diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index f323e525..71184d4a 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -162,7 +162,7 @@ func (ts *testSuite) runAndCheck(t *testing.T, c checker, clientRunner, backendR require.NoError(t, ts.mc.err) require.NoError(t, ts.mb.err) if ts.mb.err != nil { - require.True(t, IsMySQLError(ts.mp.err)) + require.True(t, pnet.IsMySQLError(ts.mp.err)) } if clientRunner != nil && backendRunner != nil { // Ensure all the packets are forwarded. diff --git a/pkg/proxy/net/compress.go b/pkg/proxy/net/compress.go index 2f4dd407..fb6a9be5 100644 --- a/pkg/proxy/net/compress.go +++ b/pkg/proxy/net/compress.go @@ -8,6 +8,7 @@ import ( "compress/zlib" "io" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/klauspost/compress/zstd" "github.com/pingcap/tiproxy/lib/util/errors" "go.uber.org/zap" @@ -45,7 +46,7 @@ func (p *PacketIO) SetCompressionAlgorithm(algorithm CompressAlgorithm, zstdLeve p.readWriter = newCompressedReadWriter(p.readWriter, algorithm, zstdLevel, p.logger) case CompressionNone: default: - return errors.Errorf("Unknown compression algorithm %d", algorithm) + return errors.Wrapf(mysql.ErrMalformPacket, "Unknown compression algorithm %d", algorithm) } return nil } @@ -272,7 +273,7 @@ func (crw *compressedReadWriter) compress(data []byte) ([]byte, error) { compressWriter, err = zstd.NewWriter(&compressedPacket, zstd.WithEncoderLevel(crw.zstdLevel)) } if err != nil { - return nil, errors.WithStack(err) + return nil, errors.WithStack(errors.Wrap(mysql.ErrMalformPacket, err)) } if _, err = compressWriter.Write(data); err != nil { return nil, errors.WithStack(err) @@ -289,12 +290,12 @@ func (crw *compressedReadWriter) uncompress(data []byte, uncompressedLength int) switch crw.algorithm { case CompressionZlib: if compressedReader, err = zlib.NewReader(bytes.NewReader(data)); err != nil { - return errors.WithStack(err) + return errors.WithStack(errors.Wrap(mysql.ErrMalformPacket, err)) } case CompressionZstd: var decoder *zstd.Decoder if decoder, err = zstd.NewReader(bytes.NewReader(data)); err != nil { - return errors.WithStack(err) + return errors.WithStack(errors.Wrap(mysql.ErrMalformPacket, err)) } compressedReader = decoder.IOReadCloser() } diff --git a/pkg/proxy/net/error.go b/pkg/proxy/net/error.go index 00409b32..9c7f290d 100644 --- a/pkg/proxy/net/error.go +++ b/pkg/proxy/net/error.go @@ -15,35 +15,3 @@ var ( ErrCloseConn = errors.New("failed to close the connection") ErrHandshakeTLS = errors.New("failed to complete tls handshake") ) - -// UserError is returned to the client. -// err is used to log and userMsg is used to report to the user. -type UserError struct { - err error - userMsg string -} - -func WrapUserError(err error, userMsg string) *UserError { - if err == nil { - return nil - } - if ue, ok := err.(*UserError); ok { - return ue - } - return &UserError{ - err: err, - userMsg: userMsg, - } -} - -func (ue *UserError) UserMsg() string { - return ue.userMsg -} - -func (ue *UserError) Unwrap() error { - return ue.err -} - -func (ue *UserError) Error() string { - return ue.err.Error() -} diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index 74a98ccd..cca298c6 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -433,6 +433,12 @@ func IsErrorPacket(firstByte byte) bool { return firstByte == ErrHeader.Byte() } +// IsMySQLError returns true if the error is a MySQL error. +func IsMySQLError(err error) bool { + var myerr *gomysql.MyError + return errors.As(err, &myerr) +} + // The connection attribute names that are logged. // https://dev.mysql.com/doc/mysql-perfschema-excerpt/8.2/en/performance-schema-connection-attribute-tables.html const ( diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index f17ba3e3..bbf60f5a 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -85,7 +85,7 @@ func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err if len(pkt) < 32 { p.logger.Error("got malformed handshake response", zap.ByteString("packetData", pkt)) - err = WrapUserError(mysql.ErrMalformPacket, mysql.ErrMalformPacket.Error()) + err = mysql.ErrMalformPacket return } @@ -132,11 +132,7 @@ func (p *PacketIO) WriteUserError(err error) { if err == nil { return } - var ue *UserError - if !errors.As(err, &ue) { - return - } - myErr := mysql.NewError(mysql.ER_UNKNOWN_ERROR, ue.UserMsg()) + myErr := mysql.NewError(mysql.ER_UNKNOWN_ERROR, err.Error()) if writeErr := p.WriteErrPacket(myErr); writeErr != nil { p.logger.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr)) } From 3c4ad0a31e772cf1ceeee5a0467e0126620568dd Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Thu, 30 Nov 2023 20:37:34 +0800 Subject: [PATCH 3/4] add tests --- pkg/proxy/backend/authenticator.go | 30 +++++++++++++--------- pkg/proxy/backend/authenticator_test.go | 10 ++++++++ pkg/proxy/backend/backend_conn_mgr.go | 2 +- pkg/proxy/backend/cmd_processor_test.go | 11 +++++--- pkg/proxy/backend/common_test.go | 4 +-- pkg/proxy/backend/error.go | 34 +++++++++++++------------ pkg/proxy/backend/testsuite_test.go | 3 +++ pkg/proxy/net/mysql_test.go | 11 ++++++++ pkg/proxy/net/net_err.go | 5 +++- 9 files changed, 75 insertions(+), 35 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 7abaff72..98996160 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -143,7 +143,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte if errors.As(err, &warning) { logger.Warn("parse handshake response encounters error", zap.Error(err)) } else if err != nil { - return errors.Wrap(ErrProxyMalformed, err) + return err } if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil { return errors.Wrap(ErrProxyErr, err) @@ -165,7 +165,7 @@ RECONNECT: // write proxy header if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil { - return errors.Wrap(ErrBackendHandshake, err) + return err } // read backend initial handshake @@ -176,7 +176,7 @@ RECONNECT: return writeErr } } - return errors.Wrap(ErrBackendHandshake, err) + return err } if err := auth.verifyBackendCaps(logger, backendCapability); err != nil { @@ -215,7 +215,7 @@ loop: // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) { - return ErrBackendPPV2 + return errors.Wrap(ErrBackendPPV2, err) } return err } @@ -242,10 +242,10 @@ loop: switch serverPkt[0] { case pnet.OKHeader.Byte(): if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil { - return err + return errors.Wrap(ErrClientHandshake, err) } if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil { - return err + return errors.Wrap(ErrBackendHandshake, err) } return nil default: // mysql.AuthSwitchRequest, ShaCommand @@ -273,12 +273,12 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) { func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error { if len(sessionToken) == 0 { - return errors.Wrapf(ErrProxyErr, "session token is empty") + return errors.Wrapf(ErrBackendHandshake, "session token is empty") } // write proxy header if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil { - return errors.Wrap(ErrBackendHandshake, err) + return err } _, backendCapability, err := auth.readInitialHandshake(backendIO) @@ -298,17 +298,20 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac } if err = auth.handleSecondAuthResult(backendIO); err == nil { - return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel) + if err = setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil { + return errors.Wrap(ErrBackendHandshake, err) + } } return errors.Wrap(ErrBackendHandshake, err) } func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) { if serverPkt, err = backendIO.ReadPacket(); err != nil { + err = errors.Wrap(ErrBackendHandshake, err) return } if pnet.IsErrorPacket(serverPkt[0]) { - err = pnet.ParseErrorPacket(serverPkt) + err = errors.Wrap(ErrBackendHandshake, pnet.ParseErrorPacket(serverPkt)) return } capability, _, _ = pnet.ParseInitialHandshake(serverPkt) @@ -355,7 +358,7 @@ func (auth *Authenticator) writeAuthHandshake( pkt = pnet.MakeHandshakeResponse(resp) // write SSL Packet if err := backendIO.WritePacket(pkt[:32], true); err != nil { - return err + return errors.Wrap(ErrBackendHandshake, err) } // Send TLS / SSL request packet. The server must have supported TLS. tcfg := backendTLSConfig.Clone() @@ -375,7 +378,10 @@ func (auth *Authenticator) writeAuthHandshake( } // write handshake resp - return backendIO.WritePacket(pkt, true) + if err := backendIO.WritePacket(pkt, true); err != nil { + return errors.Wrap(ErrBackendHandshake, err) + } + return nil } func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error { diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index fcb0b69e..221068d0 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -71,8 +71,12 @@ func TestUnsupportedCapability(t *testing.T) { ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) { if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps { require.ErrorIs(t, ts.mp.err, ErrClientCap) + require.Nil(t, ErrToClient(ts.mp.err)) + require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err)) } else if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps { require.ErrorIs(t, ts.mp.err, ErrBackendCap) + require.Equal(t, ErrBackendCap, ErrToClient(ts.mp.err)) + require.Equal(t, SrcBackendHandshake, Error2Source(ts.mp.err)) } else { require.NoError(t, ts.mc.err) require.NoError(t, ts.mp.err) @@ -310,6 +314,7 @@ func TestAuthFail(t *testing.T) { ts, clean := newTestSuite(t, tc, cfg) ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { require.Equal(t, len(ts.mc.authData), len(ts.mb.authData)) + require.Equal(t, SrcClientAuthFail, Error2Source(ts.mp.err)) }) clean() } @@ -319,6 +324,7 @@ func TestRequireBackendTLS(t *testing.T) { tests := []struct { cfg cfgOverrider err error + src ErrorSource }{ { cfg: func(cfg *testConfig) { @@ -327,6 +333,7 @@ func TestRequireBackendTLS(t *testing.T) { cfg.backendConfig.capability |= pnet.ClientSSL }, err: ErrProxyNoTLS, + src: SrcProxyErr, }, { cfg: func(cfg *testConfig) { @@ -335,6 +342,7 @@ func TestRequireBackendTLS(t *testing.T) { cfg.backendConfig.capability &= ^pnet.ClientSSL }, err: ErrBackendNoTLS, + src: SrcBackendHandshake, }, { cfg: func(cfg *testConfig) { @@ -352,6 +360,7 @@ func TestRequireBackendTLS(t *testing.T) { ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { if tt.err != nil { require.ErrorIs(t, ts.mp.err, tt.err) + require.Equal(t, tt.src, Error2Source(ts.mp.err)) } else { require.NoError(t, ts.mp.err) } @@ -400,6 +409,7 @@ func TestProxyProtocol(t *testing.T) { if ts.mp.bcConfig.ProxyProtocol && !ts.mb.proxyProtocol { err := ErrToClient(ts.mp.err) require.Equal(t, ErrBackendPPV2, err) + require.Equal(t, SrcBackendHandshake, Error2Source(err)) } else { require.NoError(t, ts.mp.err) } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 4cf7f5bc..e0dc05e8 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -307,7 +307,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( mgr.authenticator.capability &^= pnet.ClientMultiStatements mgr.cmdProcessor.capability &^= pnet.ClientMultiStatements default: - err = errors.Errorf("unrecognized set_option value:%d", val) + err = errors.Wrapf(gomysql.ErrMalformPacket, "unrecognized set_option value:%d", val) return } case pnet.ComChangeUser: diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index 26162351..9b09c2d1 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -165,6 +165,7 @@ func TestDirectQuery(t *testing.T) { }, c: func(t *testing.T, ts *testSuite) { require.Error(t, ts.mp.err) + require.Equal(t, SrcClientSQLErr, Error2Source(ts.mp.err)) require.NoError(t, ts.mb.err) }, }, @@ -1027,7 +1028,6 @@ func TestNetworkError(t *testing.T) { clientErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) require.True(t, pnet.IsDisconnectError(ts.mc.err)) - require.NotNil(t, ErrToClient(ts.mp.err)) } backendErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) @@ -1039,10 +1039,13 @@ func TestNetworkError(t *testing.T) { ts, clean := newTestSuite(t, tc, clientExitCfg) ts.authenticateFirstTime(t, backendErrChecker) + require.Equal(t, SrcClientNetwork, Error2Source(ts.mp.err)) clean() ts, clean = newTestSuite(t, tc, backendExitCfg) ts.authenticateFirstTime(t, clientErrChecker) + require.Equal(t, ErrBackendHandshake, ErrToClient(ts.mp.err)) + require.Equal(t, SrcBackendNetwork, Error2Source(ts.mp.err)) clean() ts, clean = newTestSuite(t, tc, backendExitCfg) @@ -1051,10 +1054,12 @@ func TestNetworkError(t *testing.T) { ts, clean = newTestSuite(t, tc, clientExitCfg) ts.executeCmd(t, backendErrChecker) + require.Equal(t, SrcClientNetwork, Error2Source(ts.mp.err)) clean() - ts, clean = newTestSuite(t, tc, clientExitCfg) - ts.executeCmd(t, backendErrChecker) + ts, clean = newTestSuite(t, tc, backendExitCfg) + ts.executeCmd(t, clientErrChecker) + require.Equal(t, SrcBackendNetwork, Error2Source(ts.mp.err)) clean() ts, clean = newTestSuite(t, tc, backendExitCfg) diff --git a/pkg/proxy/backend/common_test.go b/pkg/proxy/backend/common_test.go index 16056bad..bf0fd8ab 100644 --- a/pkg/proxy/backend/common_test.go +++ b/pkg/proxy/backend/common_test.go @@ -59,11 +59,11 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() { if !enableRoute { backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String()) require.NoError(t, err) - tc.proxyBIO = pnet.NewPacketIO(backendConn, lg, pnet.DefaultConnBufferSize) + tc.proxyBIO = pnet.NewPacketIO(backendConn, lg, pnet.DefaultConnBufferSize, pnet.WithWrapError(ErrBackendConn)) } clientConn, err := tc.proxyListener.Accept() require.NoError(t, err) - tc.proxyCIO = pnet.NewPacketIO(clientConn, lg, pnet.DefaultConnBufferSize) + tc.proxyCIO = pnet.NewPacketIO(clientConn, lg, pnet.DefaultConnBufferSize, pnet.WithWrapError(ErrClientConn)) }) wg.Run(func() { conn, err := net.Dial("tcp", tc.proxyListener.Addr().String()) diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index e259f634..666d1199 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -20,7 +20,6 @@ var ( ErrClientCap = errors.New("Verify client capability failed, please upgrade the client") ErrClientHandshake = errors.New("Fails to handshake with the client") ErrClientAuthFail = errors.New("Authentication fails") - ErrProxyMalformed = errors.New("TiProxy fails to parse the packet, please contact PingCAP") ErrProxyErr = errors.New("Other serverless error") ErrProxyNoBackend = errors.New("No available TiDB instances, please make sure TiDB is available") ErrProxyNoTLS = errors.New("Require TLS enabled on TiProxy when require-backend-tls=true") @@ -36,8 +35,6 @@ func ErrToClient(err error) error { case pnet.IsMySQLError(err): // If it's a MySQL error, it should be already sent to the client. return nil - case errors.Is(err, ErrProxyMalformed): - return ErrProxyMalformed case errors.Is(err, ErrProxyNoBackend): return ErrProxyNoBackend case errors.Is(err, ErrProxyNoTLS): @@ -72,7 +69,7 @@ type ErrorSource int const ( // SrcNone includes: succeed for OnHandshake; client normally quit for OnConnClose SrcNone ErrorSource = iota - // SrcClientNetwork includes: EOF; reset by peer; connection refused + // SrcClientNetwork includes: EOF; reset by peer; connection refused; io timeout SrcClientNetwork // SrcClientHandshake includes: client capability unsupported; TLS handshake fails SrcClientHandshake @@ -88,31 +85,36 @@ const ( SrcProxyNoBackend // SrcProxyErr includes: HandshakeHandler returns error; proxy disables TLS; unexpected errors SrcProxyErr - // SrcBackendNetwork includes: EOF; reset by peer; connection refused + // SrcBackendNetwork includes: EOF; reset by peer; connection refused; io timeout SrcBackendNetwork - // SrcBackendHandshake includes: dial failure; backend capability unsupported; backend disables TLS; TLS handshake fails + // SrcBackendHandshake includes: dial failure; backend capability unsupported; backend disables TLS; TLS handshake fails; proxy protocol fails SrcBackendHandshake ) // Error2Source returns the ErrorSource by the error. func Error2Source(err error) ErrorSource { - switch { - case err == nil: + if err == nil { return SrcNone - // Disconnection errors may come from other errors such as ErrProxyNoBackend. + } + // Disconnection errors may come from other errors such as ErrProxyNoBackend and ErrBackendHandshake. // ErrClientConn and ErrBackendConn may include non-connection errors. - case pnet.IsDisconnectError(err) && errors.Is(err, ErrClientConn): - return SrcClientNetwork - case pnet.IsDisconnectError(err) && errors.Is(err, ErrBackendConn): - return SrcBackendNetwork + if pnet.IsDisconnectError(err) { + if errors.Is(err, ErrClientConn) { + return SrcClientNetwork + } else if errors.Is(err, ErrBackendConn) { + return SrcBackendNetwork + } + } + switch { + // ErrInvalidSequence and ErrMalformPacket may be wrapped with other errors such as ErrBackendHandshake. case errors.Is(err, pnet.ErrInvalidSequence), errors.Is(err, gomysql.ErrMalformPacket): // We assume the clients and TiDB are right and treat it as TiProxy bugs. return SrcProxyMalformed - case errors.Is(err, ErrClientHandshake): + case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap): return SrcClientHandshake case errors.Is(err, ErrClientAuthFail): return SrcClientAuthFail - case errors.Is(err, ErrBackendHandshake): + case errors.Is(err, ErrBackendHandshake), errors.Is(err, ErrBackendCap), errors.Is(err, ErrBackendNoTLS), errors.Is(err, ErrBackendPPV2): return SrcBackendHandshake case errors.Is(err, ErrProxyNoBackend): return SrcProxyNoBackend @@ -171,7 +173,7 @@ func (es ErrorSource) GetSourceComp() SourceComp { // Normal returns whether this error source is expected. func (es ErrorSource) Normal() bool { switch es { - case SrcNone, SrcProxyQuit, SrcClientSQLErr: + case SrcNone, SrcProxyQuit, SrcClientNetwork, SrcClientSQLErr: return true } return false diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 71184d4a..303a34f7 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -190,6 +190,9 @@ func (ts *testSuite) authenticateFirstTime(t *testing.T, c checker) { if ts.mc.capability&pnet.ClientConnectAttrs > 0 { require.Equal(t, ts.mc.attrs, ts.mb.attrs) } + if !ts.mb.authSucceed { + require.Equal(t, SrcClientAuthFail, Error2Source(ts.mp.err)) + } } } diff --git a/pkg/proxy/net/mysql_test.go b/pkg/proxy/net/mysql_test.go index d777b62a..d951d1ab 100644 --- a/pkg/proxy/net/mysql_test.go +++ b/pkg/proxy/net/mysql_test.go @@ -6,6 +6,8 @@ package net import ( "testing" + gomysql "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/lib/util/logger" "github.com/stretchr/testify/require" ) @@ -61,3 +63,12 @@ func TestLogAttrs(t *testing.T) { require.Contains(t, str, `"client_name": "libmysql"`) require.Contains(t, str, `"program_name": "mysql"`) } + +func TestMySQLError(t *testing.T) { + myerr := &gomysql.MyError{} + require.True(t, IsMySQLError(errors.Wrap(ErrHandshakeTLS, myerr))) + require.False(t, IsMySQLError(errors.Wrap(myerr, ErrHandshakeTLS))) + require.False(t, IsMySQLError(ErrHandshakeTLS)) + require.True(t, errors.Is(errors.Wrap(ErrHandshakeTLS, myerr), ErrHandshakeTLS)) + require.True(t, errors.Is(errors.Wrap(myerr, ErrHandshakeTLS), ErrHandshakeTLS)) +} diff --git a/pkg/proxy/net/net_err.go b/pkg/proxy/net/net_err.go index b1659356..157c7f05 100644 --- a/pkg/proxy/net/net_err.go +++ b/pkg/proxy/net/net_err.go @@ -4,7 +4,9 @@ package net import ( + "context" "io" + "os" "syscall" "github.com/pingcap/tiproxy/lib/util/errors" @@ -13,7 +15,8 @@ import ( // IsDisconnectError returns whether the error is caused by peer disconnection. func IsDisconnectError(err error) bool { switch { - case errors.Is(err, io.EOF), errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ECONNRESET): + case errors.Is(err, io.EOF), errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ECONNRESET), + errors.Is(err, os.ErrDeadlineExceeded), errors.Is(err, context.DeadlineExceeded): return true } return false From 151cf1ed07d2c7455099ca49e291a48265699d47 Mon Sep 17 00:00:00 2001 From: xhe Date: Sat, 18 Nov 2023 15:32:39 +0800 Subject: [PATCH 4/4] conf: use server-http-tls to specify http security | tidb-test=pr/2248 (#403) Signed-off-by: xhe --- conf/proxy.toml | 12 +++++---- lib/config/proxy.go | 13 +++++----- lib/config/proxy_test.go | 4 +-- pkg/manager/cert/manager.go | 44 ++++++++++++++++---------------- pkg/manager/cert/manager_test.go | 32 +++++++++++++---------- pkg/proxy/proxy.go | 2 +- pkg/server/api/config_test.go | 4 +-- pkg/server/api/server.go | 2 +- 8 files changed, 61 insertions(+), 52 deletions(-) diff --git a/conf/proxy.toml b/conf/proxy.toml index 5507d558..b688e429 100644 --- a/conf/proxy.toml +++ b/conf/proxy.toml @@ -87,9 +87,6 @@ graceful-close-conn-timeout = 15 # server object: # 1. requires: cert/key or auto-certs(generate a temporary cert, mostly for testing) # 2. optionally: ca will enable server-side client verification. If skip-ca is true with non-empty ca, server will only verify clients if it can provide any cert. Otherwise, clients must provide a cert. -# peer object: -# 1. requires: cert/key/ca or auto-certs -# 2. useless/forbid: skip-ca # client object [security.cluster-tls] @@ -98,12 +95,17 @@ graceful-close-conn-timeout = 15 # client object [security.sql-tls] - # access to other components like TiDB or PD, will use this + # access to TiDB SQL(4000) port will use this skip-ca = true # server object [security.server-tls] - # proxy SQL or HTTP port will use this + # proxy SQL port will use this + # auto-certs = true + + # server object + [security.server-http-tls] + # proxy HTTP port will use this # auto-certs = true [metrics] diff --git a/lib/config/proxy.go b/lib/config/proxy.go index 556173bc..ea8e221a 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -113,10 +113,10 @@ func (c TLSConfig) HasCA() bool { } type Security struct { - ServerTLS TLSConfig `yaml:"server-tls,omitempty" toml:"server-tls,omitempty" json:"server-tls,omitempty"` - PeerTLS TLSConfig `yaml:"peer-tls,omitempty" toml:"peer-tls,omitempty" json:"peer-tls,omitempty"` - ClusterTLS TLSConfig `yaml:"cluster-tls,omitempty" toml:"cluster-tls,omitempty" json:"cluster-tls,omitempty"` - SQLTLS TLSConfig `yaml:"sql-tls,omitempty" toml:"sql-tls,omitempty" json:"sql-tls,omitempty"` + ServerSQLTLS TLSConfig `yaml:"server-tls,omitempty" toml:"server-tls,omitempty" json:"server-tls,omitempty"` + ServerHTTPTLS TLSConfig `yaml:"server-http-tls,omitempty" toml:"server-http-tls,omitempty" json:"server-http-tls,omitempty"` + ClusterTLS TLSConfig `yaml:"cluster-tls,omitempty" toml:"cluster-tls,omitempty" json:"cluster-tls,omitempty"` + SQLTLS TLSConfig `yaml:"sql-tls,omitempty" toml:"sql-tls,omitempty" json:"sql-tls,omitempty"` } func DefaultKeepAlive() (frontend, backendHealthy, backendUnhealthy KeepAlive) { @@ -153,8 +153,8 @@ func NewConfig() *Config { cfg.Advance.IgnoreWrongNamespace = true cfg.Security.SQLTLS.MinTLSVersion = "1.1" - cfg.Security.PeerTLS.MinTLSVersion = "1.1" - cfg.Security.ServerTLS.MinTLSVersion = "1.1" + cfg.Security.ServerSQLTLS.MinTLSVersion = "1.1" + cfg.Security.ServerHTTPTLS.MinTLSVersion = "1.1" cfg.Security.ClusterTLS.MinTLSVersion = "1.1" return &cfg @@ -184,6 +184,7 @@ func (cfg *Config) Check() error { if cfg.Proxy.ConnBufferSize > 0 && (cfg.Proxy.ConnBufferSize > 16*1024*1024 || cfg.Proxy.ConnBufferSize < 1024) { return errors.Wrapf(ErrInvalidConfigValue, "conn-buffer-size must be between 1K and 16M") } + return nil } diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index b62ab628..69d822b8 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -49,13 +49,13 @@ var testProxyConfig = Config{ }, }, Security: Security{ - ServerTLS: TLSConfig{ + ServerSQLTLS: TLSConfig{ CA: "a", Cert: "b", Key: "c", AutoCerts: true, }, - PeerTLS: TLSConfig{ + ServerHTTPTLS: TLSConfig{ CA: "a", Cert: "b", Key: "c", diff --git a/pkg/manager/cert/manager.go b/pkg/manager/cert/manager.go index 1acadba3..d741401c 100644 --- a/pkg/manager/cert/manager.go +++ b/pkg/manager/cert/manager.go @@ -25,14 +25,14 @@ const ( // Currently, all the namespaces share the same certs but there might be per-namespace // certs in the future. type CertManager struct { - serverTLS *security.CertInfo // client / proxyctl -> proxy - serverTLSConfig atomic.Pointer[tls.Config] - peerTLS *security.CertInfo // proxy -> proxy - peerTLSConfig atomic.Pointer[tls.Config] - clusterTLS *security.CertInfo // proxy -> pd / tidb status port - clusterTLSConfig atomic.Pointer[tls.Config] - sqlTLS *security.CertInfo // proxy -> tidb sql port - sqlTLSConfig atomic.Pointer[tls.Config] + serverSQLTLS *security.CertInfo // client -> proxy + serverSQLTLSConfig atomic.Pointer[tls.Config] + serverHTTPTLS *security.CertInfo // proxyctl -> proxy + serverHTTPTLSConfig atomic.Pointer[tls.Config] + clusterTLS *security.CertInfo // proxy -> pd / tidb status port + clusterTLSConfig atomic.Pointer[tls.Config] + sqlTLS *security.CertInfo // proxy -> tidb sql port + sqlTLSConfig atomic.Pointer[tls.Config] cancel context.CancelFunc wg waitgroup.WaitGroup @@ -51,8 +51,8 @@ func NewCertManager() *CertManager { // cfgch can be set to nil for the serverless tier because it has no config manager. func (cm *CertManager) Init(cfg *config.Config, logger *zap.Logger, cfgch <-chan *config.Config) error { cm.logger = logger - cm.serverTLS = security.NewCert(true) - cm.peerTLS = security.NewCert(false) + cm.serverSQLTLS = security.NewCert(true) + cm.serverHTTPTLS = security.NewCert(true) cm.clusterTLS = security.NewCert(false) cm.sqlTLS = security.NewCert(false) cm.setConfig(cfg) @@ -67,8 +67,8 @@ func (cm *CertManager) Init(cfg *config.Config, logger *zap.Logger, cfgch <-chan } func (cm *CertManager) setConfig(cfg *config.Config) { - cm.serverTLS.SetConfig(cfg.Security.ServerTLS) - cm.peerTLS.SetConfig(cfg.Security.PeerTLS) + cm.serverSQLTLS.SetConfig(cfg.Security.ServerSQLTLS) + cm.serverHTTPTLS.SetConfig(cfg.Security.ServerHTTPTLS) cm.clusterTLS.SetConfig(cfg.Security.ClusterTLS) cm.sqlTLS.SetConfig(cfg.Security.SQLTLS) } @@ -77,16 +77,16 @@ func (cm *CertManager) SetRetryInterval(interval time.Duration) { cm.retryInterval.Store(int64(interval)) } -func (cm *CertManager) ServerTLS() *tls.Config { - return cm.serverTLSConfig.Load() +func (cm *CertManager) ServerSQLTLS() *tls.Config { + return cm.serverSQLTLSConfig.Load() } -func (cm *CertManager) ClusterTLS() *tls.Config { - return cm.clusterTLSConfig.Load() +func (cm *CertManager) ServerHTTPTLS() *tls.Config { + return cm.serverHTTPTLSConfig.Load() } -func (cm *CertManager) PeerTLS() *tls.Config { - return cm.peerTLSConfig.Load() +func (cm *CertManager) ClusterTLS() *tls.Config { + return cm.clusterTLSConfig.Load() } func (cm *CertManager) SQLTLS() *tls.Config { @@ -122,15 +122,15 @@ func (cm *CertManager) reloadLoop(ctx context.Context, cfgch <-chan *config.Conf // If any error happens, we still continue and use the old cert. func (cm *CertManager) reload() error { errs := make([]error, 0, 4) - if tlsConfig, err := cm.serverTLS.Reload(cm.logger); err != nil { + if tlsConfig, err := cm.serverSQLTLS.Reload(cm.logger); err != nil { errs = append(errs, err) } else { - cm.serverTLSConfig.Store(tlsConfig) + cm.serverSQLTLSConfig.Store(tlsConfig) } - if tlsConfig, err := cm.peerTLS.Reload(cm.logger); err != nil { + if tlsConfig, err := cm.serverHTTPTLS.Reload(cm.logger); err != nil { errs = append(errs, err) } else { - cm.peerTLSConfig.Store(tlsConfig) + cm.serverHTTPTLSConfig.Store(tlsConfig) } if tlsConfig, err := cm.clusterTLS.Reload(cm.logger); err != nil { errs = append(errs, err) diff --git a/pkg/manager/cert/manager_test.go b/pkg/manager/cert/manager_test.go index bdb5921f..35352cfb 100644 --- a/pkg/manager/cert/manager_test.go +++ b/pkg/manager/cert/manager_test.go @@ -73,9 +73,9 @@ func TestInit(t *testing.T) { { name: "empty", check: func(t *testing.T, cm *CertManager) { - require.Nil(t, cm.ServerTLS()) + require.Nil(t, cm.ServerSQLTLS()) require.Nil(t, cm.ClusterTLS()) - require.Nil(t, cm.PeerTLS()) + require.Nil(t, cm.ServerHTTPTLS()) require.Nil(t, cm.SQLTLS()) }, }, @@ -83,28 +83,34 @@ func TestInit(t *testing.T) { name: "server config", cfg: config.Config{ Security: config.Security{ - ServerTLS: config.TLSConfig{AutoCerts: true}, + ServerSQLTLS: config.TLSConfig{AutoCerts: true}, + ServerHTTPTLS: config.TLSConfig{AutoCerts: true}, + ClusterTLS: config.TLSConfig{AutoCerts: true}, + SQLTLS: config.TLSConfig{AutoCerts: true}, }, }, check: func(t *testing.T, cm *CertManager) { require.Nil(t, cm.ClusterTLS()) - require.Nil(t, cm.PeerTLS()) require.Nil(t, cm.SQLTLS()) - require.NotNil(t, cm.ServerTLS()) + require.NotNil(t, cm.ServerHTTPTLS()) + require.NotNil(t, cm.ServerSQLTLS()) }, }, { name: "client config", cfg: config.Config{ Security: config.Security{ - SQLTLS: config.TLSConfig{SkipCA: true}, + ServerSQLTLS: config.TLSConfig{SkipCA: true}, + ServerHTTPTLS: config.TLSConfig{SkipCA: true}, + ClusterTLS: config.TLSConfig{SkipCA: true}, + SQLTLS: config.TLSConfig{SkipCA: true}, }, }, check: func(t *testing.T, cm *CertManager) { - require.Nil(t, cm.ClusterTLS()) - require.Nil(t, cm.PeerTLS()) - require.Nil(t, cm.ServerTLS()) + require.NotNil(t, cm.ClusterTLS()) require.NotNil(t, cm.SQLTLS()) + require.Nil(t, cm.ServerHTTPTLS()) + require.Nil(t, cm.ServerSQLTLS()) }, }, { @@ -159,7 +165,7 @@ func TestRotate(t *testing.T) { cfg := &config.Config{ Workdir: tmpdir, Security: config.Security{ - ServerTLS: config.TLSConfig{ + ServerSQLTLS: config.TLSConfig{ Cert: certPath, Key: keyPath, }, @@ -270,7 +276,7 @@ func TestRotate(t *testing.T) { } require.NoError(t, certMgr.Init(cfg, lg, nil)) - stls := certMgr.ServerTLS() + stls := certMgr.ServerSQLTLS() ctls := certMgr.SQLTLS() // pre reloading test @@ -335,7 +341,7 @@ func TestBidirectional(t *testing.T) { cfg := &config.Config{ Workdir: tmpdir, Security: config.Security{ - ServerTLS: config.TLSConfig{ + ServerSQLTLS: config.TLSConfig{ Cert: certPath1, Key: keyPath1, CA: caPath2, @@ -350,7 +356,7 @@ func TestBidirectional(t *testing.T) { certMgr := NewCertManager() require.NoError(t, certMgr.Init(cfg, lg, nil)) - stls := certMgr.ServerTLS() + stls := certMgr.ServerSQLTLS() ctls := certMgr.SQLTLS() clientErr, serverErr := connectWithTLS(ctls, stls) require.NoError(t, clientErr) diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 7176ac1a..5c6baa8f 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -167,7 +167,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { s.mu.connID++ logger := s.logger.With(zap.Uint64("connID", connID), zap.String("client_addr", conn.RemoteAddr().String()), zap.String("addr", addr)) - clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), + clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerSQLTLS(), s.certMgr.SQLTLS(), s.hsHandler, connID, addr, &backend.BCConfig{ ProxyProtocol: s.mu.proxyProtocol, RequireBackendTLS: s.mu.requireBackendTLS, diff --git a/pkg/server/api/config_test.go b/pkg/server/api/config_test.go index fd957b0f..db678b4c 100644 --- a/pkg/server/api/config_test.go +++ b/pkg/server/api/config_test.go @@ -53,7 +53,7 @@ ignore-wrong-namespace = true [security.server-tls] min-tls-version = '1.1' -[security.peer-tls] +[security.server-http-tls] min-tls-version = '1.1' [security.cluster-tls] @@ -76,7 +76,7 @@ max-backups = 3 doHTTP(t, http.MethodGet, "/api/admin/config?format=json", nil, func(t *testing.T, r *http.Response) { all, err := io.ReadAll(r.Body) require.NoError(t, err) - require.Equal(t, `{"proxy":{"addr":"0.0.0.0:6000","pd-addrs":"127.0.0.1:2379","require-backend-tls":true,"frontend-keepalive":{"enabled":true},"backend-healthy-keepalive":{"enabled":true,"idle":60000000000,"cnt":5,"intvl":3000000000,"timeout":15000000000},"backend-unhealthy-keepalive":{"enabled":true,"idle":10000000000,"cnt":5,"intvl":1000000000,"timeout":5000000000},"graceful-close-conn-timeout":15},"api":{"addr":"0.0.0.0:3080"},"advance":{"ignore-wrong-namespace":true},"security":{"server-tls":{"min-tls-version":"1.1"},"peer-tls":{"min-tls-version":"1.1"},"cluster-tls":{"min-tls-version":"1.1"},"sql-tls":{"min-tls-version":"1.1"}},"metrics":{"metrics-addr":"","metrics-interval":0},"log":{"encoder":"tidb","level":"info","log-file":{"max-size":300,"max-days":3,"max-backups":3}}}`, + require.Equal(t, `{"proxy":{"addr":"0.0.0.0:6000","pd-addrs":"127.0.0.1:2379","require-backend-tls":true,"frontend-keepalive":{"enabled":true},"backend-healthy-keepalive":{"enabled":true,"idle":60000000000,"cnt":5,"intvl":3000000000,"timeout":15000000000},"backend-unhealthy-keepalive":{"enabled":true,"idle":10000000000,"cnt":5,"intvl":1000000000,"timeout":5000000000},"graceful-close-conn-timeout":15},"api":{"addr":"0.0.0.0:3080"},"advance":{"ignore-wrong-namespace":true},"security":{"server-tls":{"min-tls-version":"1.1"},"server-http-tls":{"min-tls-version":"1.1"},"cluster-tls":{"min-tls-version":"1.1"},"sql-tls":{"min-tls-version":"1.1"}},"metrics":{"metrics-addr":"","metrics-interval":0},"log":{"encoder":"tidb","level":"info","log-file":{"max-size":300,"max-days":3,"max-backups":3}}}`, string(regexp.MustCompile(`"workdir":"[^"]+",`).ReplaceAll(all, nil))) require.Equal(t, http.StatusOK, r.StatusCode) }) diff --git a/pkg/server/api/server.go b/pkg/server/api/server.go index c33384e6..daae792f 100644 --- a/pkg/server/api/server.go +++ b/pkg/server/api/server.go @@ -117,7 +117,7 @@ func NewServer(cfg config.API, lg *zap.Logger, } } - if tlscfg := crtmgr.ServerTLS(); tlscfg != nil { + if tlscfg := crtmgr.ServerHTTPTLS(); tlscfg != nil { h.listener = tls.NewListener(h.listener, tlscfg) }