From 597cd3fc64780e68c0134056edb53505862ad3b0 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Fri, 24 Jan 2025 10:35:08 +0000 Subject: [PATCH 1/2] kube: properly return the reason for connection disruption (#51398) * kube: properly return the reason for connection disruption There are several cases where connection monitor can terminate an ongoing connection. Iddle timeout, certificate expiring among others are some reasons for the connection to be terminated. For Kubernetes access, the underlying error is never propagated back to the client so they don't receive the reason for the exec session being terminated. This PR fixes that by adding an hook to write the client error response into the connection error channel for clients to be aware. Part of https://github.com/gravitational/teleport/issues/18496 * handle review comments * handle review comments --- integration/kube_integration_test.go | 10 ++++- lib/kube/proxy/forwarder.go | 48 +++++++++++++++++++---- lib/kube/proxy/portforward_spdy.go | 2 +- lib/kube/proxy/portforward_websocket.go | 4 +- lib/kube/proxy/remotecommand.go | 21 ++++++++-- lib/kube/proxy/remotecommand_websocket.go | 20 +++++++++- lib/srv/monitor.go | 11 ++++++ 7 files changed, 100 insertions(+), 16 deletions(-) diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index 27ad39e51a13a..a6854a8d81748 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -1150,6 +1150,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) { ClientIdleTimeout: types.NewDuration(500 * time.Millisecond), }, disconnectTimeout: 2 * time.Second, + verifyError: errorContains("Client exceeded idle timeout of"), }, { name: "expired cert", @@ -1158,6 +1159,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) { MaxSessionTTL: types.NewDuration(3 * time.Second), }, disconnectTimeout: 6 * time.Second, + verifyError: errorContains("client certificate expire"), }, } @@ -1252,9 +1254,15 @@ func runKubeDisconnectTest(t *testing.T, suite *KubeSuite, tc disconnectTestCase tty: true, stdin: term, }) - require.NoError(t, err) + require.NoError(t, tc.verifyError(err)) }() + require.Eventually(t, func() bool { + // wait for the shell prompt + return strings.Contains(term.AllOutput(), "#") + }, 5*time.Second, 10*time.Millisecond, "Failed to get shell prompt. "+ + "If this fails, the exec command is likely hanging and never reaching the kind cluster") + // lets type something followed by "enter" and then hang the session require.NoError(t, enterInput(sessionCtx, term, "echo boring platypus\r\n", ".*boring platypus.*")) time.Sleep(tc.disconnectTimeout) diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 432a971703caa..6a30da7e59263 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -417,6 +417,9 @@ type authContext struct { recordingConfig types.SessionRecordingConfig // clientIdleTimeout sets information on client idle timeout clientIdleTimeout time.Duration + // clientIdleTimeoutMessage is the message to be displayed to the user + // when the client idle timeout is reached + clientIdleTimeoutMessage string // disconnectExpiredCert if set, controls the time when the connection // should be disconnected because the client cert expires disconnectExpiredCert time.Time @@ -805,13 +808,14 @@ func (f *Forwarder) setupContext( } return &authContext{ - clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()), - sessionTTL: sessionTTL, - Context: authCtx, - recordingConfig: recordingConfig, - kubeClusterName: kubeCluster, - certExpires: identity.Expires, - disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref), + clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()), + clientIdleTimeoutMessage: netConfig.GetClientIdleTimeoutMessage(), + sessionTTL: sessionTTL, + Context: authCtx, + recordingConfig: recordingConfig, + kubeClusterName: kubeCluster, + certExpires: identity.Expires, + disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref), teleportCluster: teleportClusterClient{ name: teleportClusterName, remoteAddr: utils.NetAddr{AddrNetwork: "tcp", Addr: req.RemoteAddr}, @@ -1666,6 +1670,8 @@ func (f *Forwarder) exec(authCtx *authContext, w http.ResponseWriter, req *http. return upgradeRequestToRemoteCommandProxy(request, func(proxy *remoteCommandProxy) error { + sess.sendErrStatus = proxy.writeStatus + if !sess.isLocalKubernetesCluster { // We're forwarding this to another kubernetes_service instance, let it handle multiplexing. return f.remoteExec(authCtx, w, req, p, sess, request, proxy) @@ -2286,6 +2292,8 @@ type clusterSession struct { connCtx context.Context // connMonitorCancel is the conn monitor connMonitorCancel function. connMonitorCancel context.CancelCauseFunc + // sendErrStatus is a function that sends an error status to the client. + sendErrStatus func(status *kubeerrors.StatusError) error } // close cancels the connection monitor context if available. @@ -2324,6 +2332,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n LockTargets: lockTargets, DisconnectExpiredCert: s.disconnectExpiredCert, ClientIdleTimeout: s.clientIdleTimeout, + IdleTimeoutMessage: s.clientIdleTimeoutMessage, Clock: s.parent.cfg.Clock, Tracker: tc, Conn: tc, @@ -2333,6 +2342,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n Entry: s.parent.log, Emitter: s.parent.cfg.AuthClient, EmitterContext: s.parent.ctx, + MessageWriter: formatForwardResponseError(s.sendErrStatus), }) if err != nil { tc.CloseWithCause(err) @@ -2694,3 +2704,27 @@ func errorToKubeStatusReason(err error, code int) metav1.StatusReason { return metav1.StatusReasonUnknown } } + +// formatForwardResponseError formats the error response from the connection +// monitor to a Kubernetes API error response. +type formatForwardResponseError func(status *kubeerrors.StatusError) error + +func (f formatForwardResponseError) WriteString(s string) (int, error) { + if f == nil { + return len(s), nil + } + err := f( + &kubeerrors.StatusError{ + ErrStatus: metav1.Status{ + Status: metav1.StatusFailure, + Code: http.StatusInternalServerError, + Reason: metav1.StatusReasonInternalError, + Message: s, + }, + }, + ) + if err != nil { + return 0, trace.Wrap(err) + } + return len(s), nil +} diff --git a/lib/kube/proxy/portforward_spdy.go b/lib/kube/proxy/portforward_spdy.go index 20847382c0eae..8fc6565d21e72 100644 --- a/lib/kube/proxy/portforward_spdy.go +++ b/lib/kube/proxy/portforward_spdy.go @@ -106,7 +106,7 @@ func runPortForwardingHTTPStreams(req portForwardRequest) error { defer h.Close() h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout) - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) h.run() return nil diff --git a/lib/kube/proxy/portforward_websocket.go b/lib/kube/proxy/portforward_websocket.go index 2186f4632f59e..9f48539551d1f 100644 --- a/lib/kube/proxy/portforward_websocket.go +++ b/lib/kube/proxy/portforward_websocket.go @@ -93,7 +93,7 @@ func runPortForwardingWebSocket(req portForwardRequest) error { }, }) - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) // Upgrade the request and create the virtual streams. _, streams, err := conn.Open( @@ -355,7 +355,7 @@ func runPortForwardingTunneledHTTPStreams(req portForwardRequest) error { defer h.Close() h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout) - spdyConn.SetIdleTimeout(req.idleTimeout) + spdyConn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) h.run() return nil diff --git a/lib/kube/proxy/remotecommand.go b/lib/kube/proxy/remotecommand.go index 09a9c868b43ca..0184f3e78c6ec 100644 --- a/lib/kube/proxy/remotecommand.go +++ b/lib/kube/proxy/remotecommand.go @@ -24,6 +24,7 @@ import ( "io" "net/http" "strings" + "sync" "time" "github.com/gravitational/trace" @@ -157,7 +158,7 @@ func createSPDYStreams(req remoteCommandRequest) (*remoteCommandProxy, error) { return nil, trace.ConnectionProblem(trace.BadParameter("missing connection"), "missing connection") } - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) var handler protocolHandler switch protocol { @@ -445,7 +446,7 @@ func waitStreamReply(ctx context.Context, replySent <-chan struct{}, notify chan // v4WriteStatusFunc returns a WriteStatusFunc that marshals a given api Status // as json in the error channel. func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error { - return func(status *apierrors.StatusError) error { + return writeStatusOnceFunc(func(status *apierrors.StatusError) error { st := status.Status() data, err := runtime.Encode(globalKubeCodecs.LegacyCodec(), &st) if err != nil { @@ -453,15 +454,27 @@ func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) err } _, err = stream.Write(data) return err - } + }) } func v1WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error { - return func(status *apierrors.StatusError) error { + return writeStatusOnceFunc(func(status *apierrors.StatusError) error { if status.Status().Status == metav1.StatusSuccess { return nil // send error messages } _, err := stream.Write([]byte(status.Error())) return err + }) +} + +// writeStatusOnceFunc returns a function that only calls f once, and returns the result of the first call. +func writeStatusOnceFunc(f func(status *apierrors.StatusError) error) func(status *apierrors.StatusError) error { + var once sync.Once + var err error + return func(status *apierrors.StatusError) error { + once.Do(func() { + err = f(status) + }) + return trace.Wrap(err) } } diff --git a/lib/kube/proxy/remotecommand_websocket.go b/lib/kube/proxy/remotecommand_websocket.go index abc5d3f446fdf..cb2c50e9efcb0 100644 --- a/lib/kube/proxy/remotecommand_websocket.go +++ b/lib/kube/proxy/remotecommand_websocket.go @@ -19,6 +19,8 @@ limitations under the License. package proxy import ( + "time" + "github.com/go-logr/logr" "github.com/gravitational/trace" "k8s.io/apimachinery/pkg/util/httpstream/wsstream" @@ -110,7 +112,7 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro }, }) - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) negotiatedProtocol, streams, err := conn.Open( responsewriter.GetOriginal(req.httpResponseWriter), @@ -163,3 +165,19 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro return proxy, nil } + +// adjustIdleTimeoutForConn adjusts the idle timeout for the connection +// to be 5 seconds longer than the requested idle timeout. +// This is done to prevent the connection from being closed by the server +// before the connection monitor has a chance to close it and write the +// status code. +// If the idle timeout is 0, this function returns 0 because it means the +// connection will never be closed by the server due to idleness. +func adjustIdleTimeoutForConn(idleTimeout time.Duration) time.Duration { + // If the idle timeout is 0, we don't need to adjust it because it + // means the connection will never be closed by the server due to idleness. + if idleTimeout != 0 { + idleTimeout += 5 * time.Second + } + return idleTimeout +} diff --git a/lib/srv/monitor.go b/lib/srv/monitor.go index 183aeba78b7e8..88ccc21d25980 100644 --- a/lib/srv/monitor.go +++ b/lib/srv/monitor.go @@ -238,6 +238,8 @@ type MonitorConfig struct { Entry log.FieldLogger // IdleTimeoutMessage is sent to the client when the idle timeout expires. IdleTimeoutMessage string + // CertificateExpiredMessage is sent to the client when the certificate expires. + CertificateExpiredMessage string // MessageWriter wraps a channel to send text messages to the client. Use // for disconnection messages, etc. MessageWriter io.StringWriter @@ -417,6 +419,15 @@ func (w *Monitor) start(lockWatch types.Watcher) { func (w *Monitor) disconnectClientOnExpiredCert() { reason := fmt.Sprintf("client certificate expired at %v", w.Clock.Now().UTC()) + if w.MessageWriter != nil { + msg := w.CertificateExpiredMessage + if msg == "" { + msg = reason + } + if _, err := w.MessageWriter.WriteString(msg); err != nil { + w.Logger.WarnContext(w.Context, "Failed to send certificate expiration message", "error", err) + } + } w.disconnectClient(reason) } From ba9e1c00a63be48c012b9fe2a6e82efc66ffb8c0 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Fri, 24 Jan 2025 13:42:02 +0000 Subject: [PATCH 2/2] fix slog ref --- lib/srv/monitor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/srv/monitor.go b/lib/srv/monitor.go index 88ccc21d25980..077f03efc87bd 100644 --- a/lib/srv/monitor.go +++ b/lib/srv/monitor.go @@ -425,7 +425,7 @@ func (w *Monitor) disconnectClientOnExpiredCert() { msg = reason } if _, err := w.MessageWriter.WriteString(msg); err != nil { - w.Logger.WarnContext(w.Context, "Failed to send certificate expiration message", "error", err) + w.Entry.WithError(err).Warn("Failed to send certificate expiration message") } } w.disconnectClient(reason)