Skip to content

Commit 65feb3f

Browse files
authored
[v16] kube: properly return the reason for connection disruption (#51455)
* kube: properly return the reason for connection disruption (#51398) * kube: properly return the reason for connection disruption There are several cases where connection monitor can terminate an ongoing connection. Iddle timeout, certificate expiring among others are some reasons for the connection to be terminated. For Kubernetes access, the underlying error is never propagated back to the client so they don't receive the reason for the exec session being terminated. This PR fixes that by adding an hook to write the client error response into the connection error channel for clients to be aware. Part of #18496 * handle review comments * handle review comments * fix slog ref
1 parent 938750d commit 65feb3f

File tree

7 files changed

+100
-16
lines changed

7 files changed

+100
-16
lines changed

integration/kube_integration_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) {
11501150
ClientIdleTimeout: types.NewDuration(500 * time.Millisecond),
11511151
},
11521152
disconnectTimeout: 2 * time.Second,
1153+
verifyError: errorContains("Client exceeded idle timeout of"),
11531154
},
11541155
{
11551156
name: "expired cert",
@@ -1158,6 +1159,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) {
11581159
MaxSessionTTL: types.NewDuration(3 * time.Second),
11591160
},
11601161
disconnectTimeout: 6 * time.Second,
1162+
verifyError: errorContains("client certificate expire"),
11611163
},
11621164
}
11631165

@@ -1252,9 +1254,15 @@ func runKubeDisconnectTest(t *testing.T, suite *KubeSuite, tc disconnectTestCase
12521254
tty: true,
12531255
stdin: term,
12541256
})
1255-
require.NoError(t, err)
1257+
require.NoError(t, tc.verifyError(err))
12561258
}()
12571259

1260+
require.Eventually(t, func() bool {
1261+
// wait for the shell prompt
1262+
return strings.Contains(term.AllOutput(), "#")
1263+
}, 5*time.Second, 10*time.Millisecond, "Failed to get shell prompt. "+
1264+
"If this fails, the exec command is likely hanging and never reaching the kind cluster")
1265+
12581266
// lets type something followed by "enter" and then hang the session
12591267
require.NoError(t, enterInput(sessionCtx, term, "echo boring platypus\r\n", ".*boring platypus.*"))
12601268
time.Sleep(tc.disconnectTimeout)

lib/kube/proxy/forwarder.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,9 @@ type authContext struct {
417417
recordingConfig types.SessionRecordingConfig
418418
// clientIdleTimeout sets information on client idle timeout
419419
clientIdleTimeout time.Duration
420+
// clientIdleTimeoutMessage is the message to be displayed to the user
421+
// when the client idle timeout is reached
422+
clientIdleTimeoutMessage string
420423
// disconnectExpiredCert if set, controls the time when the connection
421424
// should be disconnected because the client cert expires
422425
disconnectExpiredCert time.Time
@@ -805,13 +808,14 @@ func (f *Forwarder) setupContext(
805808
}
806809

807810
return &authContext{
808-
clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()),
809-
sessionTTL: sessionTTL,
810-
Context: authCtx,
811-
recordingConfig: recordingConfig,
812-
kubeClusterName: kubeCluster,
813-
certExpires: identity.Expires,
814-
disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref),
811+
clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()),
812+
clientIdleTimeoutMessage: netConfig.GetClientIdleTimeoutMessage(),
813+
sessionTTL: sessionTTL,
814+
Context: authCtx,
815+
recordingConfig: recordingConfig,
816+
kubeClusterName: kubeCluster,
817+
certExpires: identity.Expires,
818+
disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref),
815819
teleportCluster: teleportClusterClient{
816820
name: teleportClusterName,
817821
remoteAddr: utils.NetAddr{AddrNetwork: "tcp", Addr: req.RemoteAddr},
@@ -1666,6 +1670,8 @@ func (f *Forwarder) exec(authCtx *authContext, w http.ResponseWriter, req *http.
16661670

16671671
return upgradeRequestToRemoteCommandProxy(request,
16681672
func(proxy *remoteCommandProxy) error {
1673+
sess.sendErrStatus = proxy.writeStatus
1674+
16691675
if !sess.isLocalKubernetesCluster {
16701676
// We're forwarding this to another kubernetes_service instance, let it handle multiplexing.
16711677
return f.remoteExec(authCtx, w, req, p, sess, request, proxy)
@@ -2286,6 +2292,8 @@ type clusterSession struct {
22862292
connCtx context.Context
22872293
// connMonitorCancel is the conn monitor connMonitorCancel function.
22882294
connMonitorCancel context.CancelCauseFunc
2295+
// sendErrStatus is a function that sends an error status to the client.
2296+
sendErrStatus func(status *kubeerrors.StatusError) error
22892297
}
22902298

22912299
// close cancels the connection monitor context if available.
@@ -2324,6 +2332,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n
23242332
LockTargets: lockTargets,
23252333
DisconnectExpiredCert: s.disconnectExpiredCert,
23262334
ClientIdleTimeout: s.clientIdleTimeout,
2335+
IdleTimeoutMessage: s.clientIdleTimeoutMessage,
23272336
Clock: s.parent.cfg.Clock,
23282337
Tracker: tc,
23292338
Conn: tc,
@@ -2333,6 +2342,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n
23332342
Entry: s.parent.log,
23342343
Emitter: s.parent.cfg.AuthClient,
23352344
EmitterContext: s.parent.ctx,
2345+
MessageWriter: formatForwardResponseError(s.sendErrStatus),
23362346
})
23372347
if err != nil {
23382348
tc.CloseWithCause(err)
@@ -2694,3 +2704,27 @@ func errorToKubeStatusReason(err error, code int) metav1.StatusReason {
26942704
return metav1.StatusReasonUnknown
26952705
}
26962706
}
2707+
2708+
// formatForwardResponseError formats the error response from the connection
2709+
// monitor to a Kubernetes API error response.
2710+
type formatForwardResponseError func(status *kubeerrors.StatusError) error
2711+
2712+
func (f formatForwardResponseError) WriteString(s string) (int, error) {
2713+
if f == nil {
2714+
return len(s), nil
2715+
}
2716+
err := f(
2717+
&kubeerrors.StatusError{
2718+
ErrStatus: metav1.Status{
2719+
Status: metav1.StatusFailure,
2720+
Code: http.StatusInternalServerError,
2721+
Reason: metav1.StatusReasonInternalError,
2722+
Message: s,
2723+
},
2724+
},
2725+
)
2726+
if err != nil {
2727+
return 0, trace.Wrap(err)
2728+
}
2729+
return len(s), nil
2730+
}

lib/kube/proxy/portforward_spdy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func runPortForwardingHTTPStreams(req portForwardRequest) error {
106106
defer h.Close()
107107

108108
h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout)
109-
conn.SetIdleTimeout(req.idleTimeout)
109+
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))
110110

111111
h.run()
112112
return nil

lib/kube/proxy/portforward_websocket.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func runPortForwardingWebSocket(req portForwardRequest) error {
9393
},
9494
})
9595

96-
conn.SetIdleTimeout(req.idleTimeout)
96+
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))
9797

9898
// Upgrade the request and create the virtual streams.
9999
_, streams, err := conn.Open(
@@ -355,7 +355,7 @@ func runPortForwardingTunneledHTTPStreams(req portForwardRequest) error {
355355
defer h.Close()
356356

357357
h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout)
358-
spdyConn.SetIdleTimeout(req.idleTimeout)
358+
spdyConn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))
359359

360360
h.run()
361361
return nil

lib/kube/proxy/remotecommand.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"io"
2525
"net/http"
2626
"strings"
27+
"sync"
2728
"time"
2829

2930
"github.com/gravitational/trace"
@@ -157,7 +158,7 @@ func createSPDYStreams(req remoteCommandRequest) (*remoteCommandProxy, error) {
157158
return nil, trace.ConnectionProblem(trace.BadParameter("missing connection"), "missing connection")
158159
}
159160

160-
conn.SetIdleTimeout(req.idleTimeout)
161+
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))
161162

162163
var handler protocolHandler
163164
switch protocol {
@@ -445,23 +446,35 @@ func waitStreamReply(ctx context.Context, replySent <-chan struct{}, notify chan
445446
// v4WriteStatusFunc returns a WriteStatusFunc that marshals a given api Status
446447
// as json in the error channel.
447448
func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
448-
return func(status *apierrors.StatusError) error {
449+
return writeStatusOnceFunc(func(status *apierrors.StatusError) error {
449450
st := status.Status()
450451
data, err := runtime.Encode(globalKubeCodecs.LegacyCodec(), &st)
451452
if err != nil {
452453
return trace.Wrap(err)
453454
}
454455
_, err = stream.Write(data)
455456
return err
456-
}
457+
})
457458
}
458459

459460
func v1WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
460-
return func(status *apierrors.StatusError) error {
461+
return writeStatusOnceFunc(func(status *apierrors.StatusError) error {
461462
if status.Status().Status == metav1.StatusSuccess {
462463
return nil // send error messages
463464
}
464465
_, err := stream.Write([]byte(status.Error()))
465466
return err
467+
})
468+
}
469+
470+
// writeStatusOnceFunc returns a function that only calls f once, and returns the result of the first call.
471+
func writeStatusOnceFunc(f func(status *apierrors.StatusError) error) func(status *apierrors.StatusError) error {
472+
var once sync.Once
473+
var err error
474+
return func(status *apierrors.StatusError) error {
475+
once.Do(func() {
476+
err = f(status)
477+
})
478+
return trace.Wrap(err)
466479
}
467480
}

lib/kube/proxy/remotecommand_websocket.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ limitations under the License.
1919
package proxy
2020

2121
import (
22+
"time"
23+
2224
"github.com/go-logr/logr"
2325
"github.com/gravitational/trace"
2426
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
@@ -110,7 +112,7 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro
110112
},
111113
})
112114

113-
conn.SetIdleTimeout(req.idleTimeout)
115+
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))
114116

115117
negotiatedProtocol, streams, err := conn.Open(
116118
responsewriter.GetOriginal(req.httpResponseWriter),
@@ -163,3 +165,19 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro
163165

164166
return proxy, nil
165167
}
168+
169+
// adjustIdleTimeoutForConn adjusts the idle timeout for the connection
170+
// to be 5 seconds longer than the requested idle timeout.
171+
// This is done to prevent the connection from being closed by the server
172+
// before the connection monitor has a chance to close it and write the
173+
// status code.
174+
// If the idle timeout is 0, this function returns 0 because it means the
175+
// connection will never be closed by the server due to idleness.
176+
func adjustIdleTimeoutForConn(idleTimeout time.Duration) time.Duration {
177+
// If the idle timeout is 0, we don't need to adjust it because it
178+
// means the connection will never be closed by the server due to idleness.
179+
if idleTimeout != 0 {
180+
idleTimeout += 5 * time.Second
181+
}
182+
return idleTimeout
183+
}

lib/srv/monitor.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ type MonitorConfig struct {
238238
Entry log.FieldLogger
239239
// IdleTimeoutMessage is sent to the client when the idle timeout expires.
240240
IdleTimeoutMessage string
241+
// CertificateExpiredMessage is sent to the client when the certificate expires.
242+
CertificateExpiredMessage string
241243
// MessageWriter wraps a channel to send text messages to the client. Use
242244
// for disconnection messages, etc.
243245
MessageWriter io.StringWriter
@@ -417,6 +419,15 @@ func (w *Monitor) start(lockWatch types.Watcher) {
417419

418420
func (w *Monitor) disconnectClientOnExpiredCert() {
419421
reason := fmt.Sprintf("client certificate expired at %v", w.Clock.Now().UTC())
422+
if w.MessageWriter != nil {
423+
msg := w.CertificateExpiredMessage
424+
if msg == "" {
425+
msg = reason
426+
}
427+
if _, err := w.MessageWriter.WriteString(msg); err != nil {
428+
w.Entry.WithError(err).Warn("Failed to send certificate expiration message")
429+
}
430+
}
420431
w.disconnectClient(reason)
421432
}
422433

0 commit comments

Comments
 (0)