Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] kube: properly return the reason for connection disruption #51455

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion integration/kube_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) {
ClientIdleTimeout: types.NewDuration(500 * time.Millisecond),
},
disconnectTimeout: 2 * time.Second,
verifyError: errorContains("Client exceeded idle timeout of"),
},
{
name: "expired cert",
Expand All @@ -1158,6 +1159,7 @@ func testKubeDisconnect(t *testing.T, suite *KubeSuite) {
MaxSessionTTL: types.NewDuration(3 * time.Second),
},
disconnectTimeout: 6 * time.Second,
verifyError: errorContains("client certificate expire"),
},
}

Expand Down Expand Up @@ -1252,9 +1254,15 @@ func runKubeDisconnectTest(t *testing.T, suite *KubeSuite, tc disconnectTestCase
tty: true,
stdin: term,
})
require.NoError(t, err)
require.NoError(t, tc.verifyError(err))
}()

require.Eventually(t, func() bool {
// wait for the shell prompt
return strings.Contains(term.AllOutput(), "#")
}, 5*time.Second, 10*time.Millisecond, "Failed to get shell prompt. "+
"If this fails, the exec command is likely hanging and never reaching the kind cluster")

// lets type something followed by "enter" and then hang the session
require.NoError(t, enterInput(sessionCtx, term, "echo boring platypus\r\n", ".*boring platypus.*"))
time.Sleep(tc.disconnectTimeout)
Expand Down
48 changes: 41 additions & 7 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,9 @@ type authContext struct {
recordingConfig types.SessionRecordingConfig
// clientIdleTimeout sets information on client idle timeout
clientIdleTimeout time.Duration
// clientIdleTimeoutMessage is the message to be displayed to the user
// when the client idle timeout is reached
clientIdleTimeoutMessage string
// disconnectExpiredCert if set, controls the time when the connection
// should be disconnected because the client cert expires
disconnectExpiredCert time.Time
Expand Down Expand Up @@ -805,13 +808,14 @@ func (f *Forwarder) setupContext(
}

return &authContext{
clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()),
sessionTTL: sessionTTL,
Context: authCtx,
recordingConfig: recordingConfig,
kubeClusterName: kubeCluster,
certExpires: identity.Expires,
disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref),
clientIdleTimeout: roles.AdjustClientIdleTimeout(netConfig.GetClientIdleTimeout()),
clientIdleTimeoutMessage: netConfig.GetClientIdleTimeoutMessage(),
sessionTTL: sessionTTL,
Context: authCtx,
recordingConfig: recordingConfig,
kubeClusterName: kubeCluster,
certExpires: identity.Expires,
disconnectExpiredCert: authCtx.GetDisconnectCertExpiry(authPref),
teleportCluster: teleportClusterClient{
name: teleportClusterName,
remoteAddr: utils.NetAddr{AddrNetwork: "tcp", Addr: req.RemoteAddr},
Expand Down Expand Up @@ -1666,6 +1670,8 @@ func (f *Forwarder) exec(authCtx *authContext, w http.ResponseWriter, req *http.

return upgradeRequestToRemoteCommandProxy(request,
func(proxy *remoteCommandProxy) error {
sess.sendErrStatus = proxy.writeStatus

if !sess.isLocalKubernetesCluster {
// We're forwarding this to another kubernetes_service instance, let it handle multiplexing.
return f.remoteExec(authCtx, w, req, p, sess, request, proxy)
Expand Down Expand Up @@ -2286,6 +2292,8 @@ type clusterSession struct {
connCtx context.Context
// connMonitorCancel is the conn monitor connMonitorCancel function.
connMonitorCancel context.CancelCauseFunc
// sendErrStatus is a function that sends an error status to the client.
sendErrStatus func(status *kubeerrors.StatusError) error
}

// close cancels the connection monitor context if available.
Expand Down Expand Up @@ -2324,6 +2332,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n
LockTargets: lockTargets,
DisconnectExpiredCert: s.disconnectExpiredCert,
ClientIdleTimeout: s.clientIdleTimeout,
IdleTimeoutMessage: s.clientIdleTimeoutMessage,
Clock: s.parent.cfg.Clock,
Tracker: tc,
Conn: tc,
Expand All @@ -2333,6 +2342,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (n
Entry: s.parent.log,
Emitter: s.parent.cfg.AuthClient,
EmitterContext: s.parent.ctx,
MessageWriter: formatForwardResponseError(s.sendErrStatus),
})
if err != nil {
tc.CloseWithCause(err)
Expand Down Expand Up @@ -2694,3 +2704,27 @@ func errorToKubeStatusReason(err error, code int) metav1.StatusReason {
return metav1.StatusReasonUnknown
}
}

// formatForwardResponseError adapts a Kubernetes status-sending callback to a
// string writer: each string written is wrapped into a failure Status and
// passed to the callback, so connection monitor messages reach the client as
// Kubernetes API error responses.
type formatForwardResponseError func(status *kubeerrors.StatusError) error

// WriteString reports s to the client as an internal-error Kubernetes status.
// A nil receiver silently discards the message and reports it as fully
// written. On callback failure it returns 0 and the wrapped error.
func (f formatForwardResponseError) WriteString(s string) (int, error) {
	if f == nil {
		return len(s), nil
	}

	failure := metav1.Status{
		Status:  metav1.StatusFailure,
		Code:    http.StatusInternalServerError,
		Reason:  metav1.StatusReasonInternalError,
		Message: s,
	}
	if sendErr := f(&kubeerrors.StatusError{ErrStatus: failure}); sendErr != nil {
		return 0, trace.Wrap(sendErr)
	}
	return len(s), nil
}
2 changes: 1 addition & 1 deletion lib/kube/proxy/portforward_spdy.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func runPortForwardingHTTPStreams(req portForwardRequest) error {
defer h.Close()

h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout)
conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

h.run()
return nil
Expand Down
4 changes: 2 additions & 2 deletions lib/kube/proxy/portforward_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func runPortForwardingWebSocket(req portForwardRequest) error {
},
})

conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

// Upgrade the request and create the virtual streams.
_, streams, err := conn.Open(
Expand Down Expand Up @@ -355,7 +355,7 @@ func runPortForwardingTunneledHTTPStreams(req portForwardRequest) error {
defer h.Close()

h.Debugf("Setting port forwarding streaming connection idle timeout to %s.", req.idleTimeout)
spdyConn.SetIdleTimeout(req.idleTimeout)
spdyConn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

h.run()
return nil
Expand Down
21 changes: 17 additions & 4 deletions lib/kube/proxy/remotecommand.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -157,7 +158,7 @@ func createSPDYStreams(req remoteCommandRequest) (*remoteCommandProxy, error) {
return nil, trace.ConnectionProblem(trace.BadParameter("missing connection"), "missing connection")
}

conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

var handler protocolHandler
switch protocol {
Expand Down Expand Up @@ -445,23 +446,35 @@ func waitStreamReply(ctx context.Context, replySent <-chan struct{}, notify chan
// v4WriteStatusFunc returns a WriteStatusFunc that marshals a given api Status
// as json in the error channel.
func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
return func(status *apierrors.StatusError) error {
return writeStatusOnceFunc(func(status *apierrors.StatusError) error {
st := status.Status()
data, err := runtime.Encode(globalKubeCodecs.LegacyCodec(), &st)
if err != nil {
return trace.Wrap(err)
}
_, err = stream.Write(data)
return err
}
})
}

func v1WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
return func(status *apierrors.StatusError) error {
return writeStatusOnceFunc(func(status *apierrors.StatusError) error {
if status.Status().Status == metav1.StatusSuccess {
return nil // send error messages
}
_, err := stream.Write([]byte(status.Error()))
return err
})
}

// writeStatusOnceFunc wraps f so that it is invoked at most once. The error
// produced by that single invocation is remembered and returned (wrapped) by
// the first call and by every subsequent call; the status passed to later
// calls is ignored.
func writeStatusOnceFunc(f func(status *apierrors.StatusError) error) func(status *apierrors.StatusError) error {
	var (
		guard    sync.Once
		firstErr error
	)
	return func(status *apierrors.StatusError) error {
		guard.Do(func() { firstErr = f(status) })
		return trace.Wrap(firstErr)
	}
}
20 changes: 19 additions & 1 deletion lib/kube/proxy/remotecommand_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ limitations under the License.
package proxy

import (
"time"

"github.com/go-logr/logr"
"github.com/gravitational/trace"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
Expand Down Expand Up @@ -110,7 +112,7 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro
},
})

conn.SetIdleTimeout(req.idleTimeout)
conn.SetIdleTimeout(adjustIdleTimeoutForConn(req.idleTimeout))

negotiatedProtocol, streams, err := conn.Open(
responsewriter.GetOriginal(req.httpResponseWriter),
Expand Down Expand Up @@ -163,3 +165,19 @@ func createWebSocketStreams(req remoteCommandRequest) (*remoteCommandProxy, erro

return proxy, nil
}

// adjustIdleTimeoutForConn returns the idle timeout to set on the underlying
// connection: the requested timeout extended by 5 seconds.
// The extra margin keeps the server from closing the connection for idleness
// before the connection monitor has had a chance to close it and write the
// status code.
// A zero timeout is returned unchanged, since zero means the server never
// closes the connection due to idleness.
func adjustIdleTimeoutForConn(idleTimeout time.Duration) time.Duration {
	// Zero disables idle disconnection entirely, so there is nothing to pad.
	if idleTimeout == 0 {
		return 0
	}
	return idleTimeout + 5*time.Second
}
11 changes: 11 additions & 0 deletions lib/srv/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ type MonitorConfig struct {
Entry log.FieldLogger
// IdleTimeoutMessage is sent to the client when the idle timeout expires.
IdleTimeoutMessage string
// CertificateExpiredMessage is sent to the client when the certificate expires.
CertificateExpiredMessage string
// MessageWriter wraps a channel to send text messages to the client. Use
// for disconnection messages, etc.
MessageWriter io.StringWriter
Expand Down Expand Up @@ -417,6 +419,15 @@ func (w *Monitor) start(lockWatch types.Watcher) {

// disconnectClientOnExpiredCert notifies the client that its certificate has
// expired and then disconnects it. When a MessageWriter is configured, the
// client is sent CertificateExpiredMessage, or the generated disconnect reason
// if that message is empty. A failure to deliver the message is only logged;
// the disconnect proceeds regardless.
func (w *Monitor) disconnectClientOnExpiredCert() {
	reason := fmt.Sprintf("client certificate expired at %v", w.Clock.Now().UTC())

	if writer := w.MessageWriter; writer != nil {
		// Prefer the configured message; fall back to the generated reason.
		notice := reason
		if w.CertificateExpiredMessage != "" {
			notice = w.CertificateExpiredMessage
		}
		if _, writeErr := writer.WriteString(notice); writeErr != nil {
			w.Entry.WithError(writeErr).Warn("Failed to send certificate expiration message")
		}
	}

	w.disconnectClient(reason)
}

Expand Down
Loading