diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index a6497441fecb7..b8ba4e0d3a569 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -1117,6 +1117,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) { ClientIdleTimeout: types.NewDuration(500 * time.Millisecond), }, disconnectTimeout: 2 * time.Second, + verifyError: errorContains("Client exceeded idle timeout of"), }, { name: "expired cert", @@ -1125,6 +1126,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) { MaxSessionTTL: types.NewDuration(3 * time.Second), }, disconnectTimeout: 6 * time.Second, + verifyError: errorContains("client certificate expire"), }, } @@ -1219,9 +1221,15 @@ func runKubeDisconnectTest(t *testing.T, suite *KubeSuite, tc disconnectTestCase tty: true, stdin: term, }) - require.NoError(t, err) + require.NoError(t, tc.verifyError(err)) }() + require.Eventually(t, func() bool { + // wait for the shell prompt + return strings.Contains(term.AllOutput(), "#") + }, 5*time.Second, 10*time.Millisecond, "Failed to get shell prompt. "+ + "If this fails, the exec command is likely hanging and never reaching the kind cluster") + // lets type something followed by "enter" and then hang the session require.NoError(t, enterInput(sessionCtx, term, "echo boring platypus\r\n", ".*boring platypus.*")) time.Sleep(tc.disconnectTimeout) diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 6ad431886bb1f..c9dd8ee4fc7bf 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -399,6 +399,9 @@ type authContext struct { recordingConfig types.SessionRecordingConfig // clientIdleTimeout sets information on client idle timeout clientIdleTimeout time.Duration + // clientIdleTimeoutMessage is the message to be displayed to the user + // when the client idle timeout is reached + clientIdleTimeoutMessage string // disconnectExpiredCert if set, controls the time when the connection // should be disconnected because the client cert expires disconnectExpiredCert time.Time @@ -787,13 +790,14 @@ func (f *Forwarder) setupContext( } return &authContext{ - clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()), - sessionTTL: sessionTTL, - Context: authCtx, - recordingConfig: recordingConfig, - kubeClusterName: kubeCluster, - certExpires: identity.Expires, - disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref), + clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()), + clientIdleTimeoutMessage: netConfig.GetClientIdleTimeoutMessage(), + sessionTTL: sessionTTL, + Context: authCtx, + recordingConfig: recordingConfig, + kubeClusterName: kubeCluster, + certExpires: identity.Expires, + disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref), teleportCluster: teleportClusterClient{ name: teleportClusterName, remoteAddr: utils.NetAddr{AddrNetwork: "tcp", Addr: req.RemoteAddr}, @@ -1648,6 +1652,8 @@ func (f *Forwarder) exec(authCtx *authContext, w http.ResponseWriter, req *http. return upgradeRequestToRemoteCommandProxy(request, func(proxy *remoteCommandProxy) error { + sess.sendErrStatus = proxy.writeStatus + if !sess.isLocalKubernetesCluster { // We're forwarding this to another kubernetes_service instance, let it handle multiplexing. return f.remoteExec(authCtx, w, req, p, sess, request, proxy) @@ -2166,6 +2172,8 @@ type clusterSession struct { connCtx context.Context // connMonitorCancel is the conn monitor connMonitorCancel function. connMonitorCancel context.CancelCauseFunc + // sendErrStatus is a function that sends an error status to the client. + sendErrStatus func(status *kubeerrors.StatusError) error } // close cancels the connection monitor context if available. @@ -2204,6 +2212,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n LockTargets: lockTargets, DisconnectExpiredCert: s.disconnectExpiredCert, ClientIdleTimeout: s.clientIdleTimeout, + IdleTimeoutMessage: s.clientIdleTimeoutMessage, Clock: s.parent.cfg.Clock, Tracker: tc, Conn: tc, @@ -2213,6 +2222,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n Entry: s.parent.log, Emitter: s.parent.cfg.AuthClient, EmitterContext: s.parent.ctx, + MessageWriter: formatForwardResponseError(s.sendErrStatus), }) if err != nil { tc.CloseWithCause(err) @@ -2549,3 +2559,27 @@ func errorToKubeStatusReason(err error, code int) metav1.StatusReason { return metav1.StatusReasonUnknown } } + +// formatForwardResponseError formats the error response from the connection +// monitor to a Kubernetes API error response. +type formatForwardResponseError func(status *kubeerrors.StatusError) error + +func (f formatForwardResponseError) WriteString(s string) (int, error) { + if f == nil { + return len(s), nil + } + err := f( + &kubeerrors.StatusError{ + ErrStatus: metav1.Status{ + Status: metav1.StatusFailure, + Code: http.StatusInternalServerError, + Reason: metav1.StatusReasonInternalError, + Message: s, + }, + }, + ) + if err != nil { + return 0, trace.Wrap(err) + } + return len(s), nil +} diff --git a/lib/kube/proxy/portforward_spdy.go b/lib/kube/proxy/portforward_spdy.go index 20847382c0eae..8fc6565d21e72 100644 --- a/lib/kube/proxy/portforward_spdy.go +++ b/lib/kube/proxy/portforward_spdy.go @@ -106,7 +106,7 @@ func runPortForwardingHTTPStreams(req portForwardRequest) error { defer h.Close() h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout) - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) h.run() return nil diff --git a/lib/kube/proxy/portforward_websocket.go b/lib/kube/proxy/portforward_websocket.go index 5cba5d99a6913..960941ac616df 100644 --- a/lib/kube/proxy/portforward_websocket.go +++ b/lib/kube/proxy/portforward_websocket.go @@ -89,7 +89,7 @@ func runPortForwardingWebSocket(req portForwardRequest) error { }, }) - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) // Upgrade the request and create the virtual streams. _, streams, err := conn.Open( diff --git a/lib/kube/proxy/remotecommand.go b/lib/kube/proxy/remotecommand.go index 44431a3efd9cb..72e267f82f1aa 100644 --- a/lib/kube/proxy/remotecommand.go +++ b/lib/kube/proxy/remotecommand.go @@ -24,6 +24,7 @@ import ( "io" "net/http" "strings" + "sync" "time" "github.com/gravitational/trace" @@ -154,7 +155,7 @@ func createSPDYStreams(req remoteCommandRequest) (*remoteCommandProxy, error) { return nil, trace.ConnectionProblem(trace.BadParameter("missing connection"), "missing connection") } - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) var handler protocolHandler switch protocol { @@ -442,7 +443,7 @@ func waitStreamReply(ctx context.Context, replySent <-chan struct{}, notify chan // v4WriteStatusFunc returns a WriteStatusFunc that marshals a given api Status // as json in the error channel. func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error { - return func(status *apierrors.StatusError) error { + return writeStatusOnceFunc(func(status *apierrors.StatusError) error { st := status.Status() data, err := runtime.Encode(globalKubeCodecs.LegacyCodec(), &st) if err != nil { @@ -450,15 +451,27 @@ func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) err } _, err = stream.Write(data) return err - } + }) } func v1WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error { - return func(status *apierrors.StatusError) error { + return writeStatusOnceFunc(func(status *apierrors.StatusError) error { if status.Status().Status == metav1.StatusSuccess { return nil // send error messages } _, err := stream.Write([]byte(status.Error())) return err + }) +} + +// writeStatusOnceFunc returns a function that only calls f once, and returns the result of the first call. +func writeStatusOnceFunc(f func(status *apierrors.StatusError) error) func(status *apierrors.StatusError) error { + var once sync.Once + var err error + return func(status *apierrors.StatusError) error { + once.Do(func() { + err = f(status) + }) + return trace.Wrap(err) } } diff --git a/lib/kube/proxy/remotecommand_websocket.go b/lib/kube/proxy/remotecommand_websocket.go index abc5d3f446fdf..cb2c50e9efcb0 100644 --- a/lib/kube/proxy/remotecommand_websocket.go +++ b/lib/kube/proxy/remotecommand_websocket.go @@ -19,6 +19,8 @@ limitations under the License. package proxy import ( + "time" + "github.com/go-logr/logr" "github.com/gravitational/trace" "k8s.io/apimachinery/pkg/util/httpstream/wsstream" @@ -110,7 +112,7 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro }, }) - conn.SetIdleTimeout(req.idleTimeout) + conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout)) negotiatedProtocol, streams, err := conn.Open( responsewriter.GetOriginal(req.httpResponseWriter), @@ -163,3 +165,19 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro return proxy, nil } + +// adjustIdleTimeoutForConn adjusts the idle timeout for the connection +// to be 5 seconds longer than the requested idle timeout. +// This is done to prevent the connection from being closed by the server +// before the connection monitor has a chance to close it and write the +// status code. +// If the idle timeout is 0, this function returns 0 because it means the +// connection will never be closed by the server due to idleness. +func adjustIdleTimeoutForConn(idleTimeout time.Duration) time.Duration { + // If the idle timeout is 0, we don't need to adjust it because it + // means the connection will never be closed by the server due to idleness. + if idleTimeout != 0 { + idleTimeout += 5 * time.Second + } + return idleTimeout +} diff --git a/lib/srv/monitor.go b/lib/srv/monitor.go index 183aeba78b7e8..88ccc21d25980 100644 --- a/lib/srv/monitor.go +++ b/lib/srv/monitor.go @@ -238,6 +238,8 @@ type MonitorConfig struct { Entry log.FieldLogger // IdleTimeoutMessage is sent to the client when the idle timeout expires. IdleTimeoutMessage string + // CertificateExpiredMessage is sent to the client when the certificate expires. + CertificateExpiredMessage string // MessageWriter wraps a channel to send text messages to the client. Use // for disconnection messages, etc. MessageWriter io.StringWriter @@ -417,6 +419,15 @@ func (w *Monitor) start(lockWatch types.Watcher) { func (w *Monitor) disconnectClientOnExpiredCert() { reason := fmt.Sprintf("client certificate expired at %v", w.Clock.Now().UTC()) + if w.MessageWriter != nil { + msg := w.CertificateExpiredMessage + if msg == "" { + msg = reason + } + if _, err := w.MessageWriter.WriteString(msg); err != nil { + w.Logger.WarnContext(w.Context, "Failed to send certificate expiration message", "error", err) + } + } w.disconnectClient(reason) }