diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 968c50383734b..45fdecac8c31b 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -98,7 +98,6 @@ import ( "github.com/gravitational/teleport/lib/secret" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/srv/desktop/tdp" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" @@ -188,10 +187,6 @@ type Handler struct { // an authenticated websocket so unauthenticated sockets dont get left // open. wsIODeadline time.Duration - - // withheldMessages is a list of any messages that came from the browser which were - // withheld while the user was performing MFA. - withheldMessages []tdp.Message } // HandlerOption is a functional argument - an option that can be passed diff --git a/lib/web/desktop.go b/lib/web/desktop.go index d763562089eb6..858c6a020a613 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -175,8 +175,10 @@ func (h *Handler) createDesktopConnection( return sendTDPError(err) } + // Holds any messages withheld while issuing certs. + var withheld []tdp.Message // Issue certificate for the user/desktop combination and perform MFA ceremony if required. - certs, err := h.issueCerts(ctx, ws, sctx, mfaRequired, certsReq) + certs, err := h.issueCerts(ctx, ws, sctx, mfaRequired, certsReq, &withheld) if err != nil { return sendTDPError(err) } @@ -222,11 +224,15 @@ func (h *Handler) createDesktopConnection( if err != nil { return sendTDPError(err) } - for _, msg := range h.withheldMessages { + for _, msg := range withheld { + log.Debugf("Sending withheld message: %v", msg) if err := tdpConn.WriteMessage(msg); err != nil { return sendTDPError(err) } } + // nil out the slice so we don't hang on to these messages + // for the rest of the connection + withheld = nil // proxyWebsocketConn hangs here until connection is closed handleProxyWebsocketConnErr( @@ -314,9 +320,10 @@ func (h *Handler) issueCerts( sctx *SessionContext, mfaRequired bool, certsReq *proto.UserCertsRequest, + withheld *[]tdp.Message, ) (certs *proto.Certs, err error) { if mfaRequired { - certs, err = h.performMFACeremony(ctx, ws, sctx, certsReq) + certs, err = h.performMFACeremony(ctx, ws, sctx, certsReq, withheld) if err != nil { return nil, trace.Wrap(err) } @@ -363,6 +370,7 @@ func (h *Handler) performMFACeremony( ws *websocket.Conn, sctx *SessionContext, certsReq *proto.UserCertsRequest, + withheld *[]tdp.Message, ) (_ *proto.Certs, err error) { ctx, span := h.tracer.Start(ctx, "desktop/performMFACeremony") defer func() { @@ -413,7 +421,7 @@ func (h *Handler) performMFACeremony( if err != nil { return nil, trace.Wrap(err) } - h.withheldMessages = append(h.withheldMessages, msg) + *withheld = append(*withheld, msg) continue }