diff --git a/api/types/session_tracker.go b/api/types/session_tracker.go index db07ea2578db5..2892db170085c 100644 --- a/api/types/session_tracker.go +++ b/api/types/session_tracker.go @@ -39,6 +39,7 @@ const ( DatabaseSessionKind SessionKind = "db" AppSessionKind SessionKind = "app" WindowsDesktopSessionKind SessionKind = "desktop" + UnknownSessionKind SessionKind = "" ) // SessionParticipantMode is the mode that determines what you can do when you join a session. diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index c48bdf2684c83..40e94526e472a 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -196,30 +196,14 @@ func (a *ServerWithRoles) actionWithExtendedContext(kind, verb string, extendCon // actionForKindSession is a special checker that grants access to session // recordings. It can allow access to a specific recording based on the // `where` section of the user's access rule for kind `session`. -func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) (types.SessionKind, error) { - sessionEnd, err := a.findSessionEndEvent(ctx, sid) - - extendContext := func(ctx *services.Context) error { - ctx.Session = sessionEnd +func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) error { + extendContext := func(servicesCtx *services.Context) error { + sessionEnd, err := a.findSessionEndEvent(ctx, sid) + servicesCtx.Session = sessionEnd return trace.Wrap(err) } - var sessionKind types.SessionKind - switch e := sessionEnd.(type) { - case *apievents.SessionEnd: - sessionKind = types.SSHSessionKind - if e.KubernetesCluster != "" { - sessionKind = types.KubernetesSessionKind - } - case *apievents.DatabaseSessionEnd: - sessionKind = types.DatabaseSessionKind - case *apievents.AppSessionEnd: - sessionKind = types.AppSessionKind - case *apievents.WindowsDesktopSessionEnd: - sessionKind = types.WindowsDesktopSessionKind - } - - return sessionKind, trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext)) + return trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext)) } // localServerAction returns an access denied error if the role is not one of the builtin server roles. @@ -6121,29 +6105,25 @@ func (a *ServerWithRoles) ReplaceRemoteLocks(ctx context.Context, clusterName st // channel if one is encountered. Otherwise the event channel is closed when the stream ends. // The event channel is not closed on error to prevent race conditions in downstream select statements. func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { - createErrorChannel := func(err error) (chan apievents.AuditEvent, chan error) { - e := make(chan error, 1) - e <- trace.Wrap(err) - return nil, e - } - err := a.localServerAction() isTeleportServer := err == nil - var sessionType types.SessionKind - if !isTeleportServer { - var err error - sessionType, err = a.actionForKindSession(ctx, sessionID) - if err != nil { - c, e := make(chan apievents.AuditEvent), make(chan error, 1) - e <- trace.Wrap(err) - return c, e - } + // StreamSessionEvents can be called internally, and when that + // happens we don't want to emit an event or check for permissions. + if isTeleportServer { + return a.alog.StreamSessionEvents(ctx, sessionID, startIndex) } - // StreamSessionEvents can be called internally, and when that happens we don't want to emit an event. - shouldEmitAuditEvent := !isTeleportServer - if shouldEmitAuditEvent { + if err := a.actionForKindSession(ctx, sessionID); err != nil { + c, e := make(chan apievents.AuditEvent), make(chan error, 1) + e <- trace.Wrap(err) + return c, e + } + + // We can only determine the session type after the streaming started. For + // this reason, we delay the emit audit event until the first event or if + // the streaming returns an error. + cb := func(evt apievents.AuditEvent, _ error) { if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{ Metadata: apievents.Metadata{ Type: events.SessionRecordingAccessEvent, @@ -6151,14 +6131,34 @@ func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID ses }, SessionID: sessionID.String(), UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(), - SessionType: string(sessionType), + SessionType: string(sessionTypeFromStartEvent(evt)), Format: metadata.SessionRecordingFormatFromContext(ctx), }); err != nil { - return createErrorChannel(err) + log.WithError(err).Errorf("Failed to emit stream session event audit event") } } - return a.alog.StreamSessionEvents(ctx, sessionID, startIndex) + return a.alog.StreamSessionEvents(events.ContextWithSessionStartCallback(ctx, cb), sessionID, startIndex) +} + +// sessionTypeFromStartEvent determines the session type given the session start +// event. +func sessionTypeFromStartEvent(sessionStart apievents.AuditEvent) types.SessionKind { + switch e := sessionStart.(type) { + case *apievents.SessionStart: + if e.KubernetesCluster != "" { + return types.KubernetesSessionKind + } + return types.SSHSessionKind + case *apievents.DatabaseSessionStart: + return types.DatabaseSessionKind + case *apievents.AppSessionStart: + return types.AppSessionKind + case *apievents.WindowsDesktopSessionStart: + return types.WindowsDesktopSessionKind + default: + return types.UnknownSessionKind + } } // CreateApp creates a new application resource. diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 1576986648191..f6ad5315cd6fb 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -2267,7 +2267,29 @@ func TestStreamSessionEvents(t *testing.T) { func TestStreamSessionEvents_SessionType(t *testing.T) { t.Parallel() - srv := newTestTLSServer(t) + authServerConfig := TestAuthServerConfig{ + Dir: t.TempDir(), + Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()), + } + require.NoError(t, authServerConfig.CheckAndSetDefaults()) + + uploader := eventstest.NewMemoryUploader() + localLog, err := events.NewAuditLog(events.AuditLogConfig{ + DataDir: authServerConfig.Dir, + ServerID: authServerConfig.ClusterName, + Clock: authServerConfig.Clock, + UploadHandler: uploader, + }) + require.NoError(t, err) + authServerConfig.AuditLog = localLog + + as, err := NewTestAuthServer(authServerConfig) + require.NoError(t, err) + + srv, err := as.NewTestTLSServer() + require.NoError(t, err) + t.Cleanup(func() { srv.Close() }) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) @@ -2278,22 +2300,29 @@ func TestStreamSessionEvents_SessionType(t *testing.T) { identity := TestUser(user.GetName()) clt, err := srv.NewClient(identity) require.NoError(t, err) - sessionID := "44c6cea8-362f-11ea-83aa-125400432324" + sessionID := session.NewID() - // Emitting a session end event will cause the listing to correctly locate - // the recording (even if there might not be a recording file to stream). - require.NoError(t, srv.Auth().EmitAuditEvent(ctx, &apievents.DatabaseSessionEnd{ + streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ + Uploader: uploader, + }) + require.NoError(t, err) + stream, err := streamer.CreateAuditStream(ctx, sessionID) + require.NoError(t, err) + // The event is not required to pass through the auth server, we only need + // the upload to be present. + require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(&apievents.DatabaseSessionStart{ Metadata: apievents.Metadata{ - Type: events.DatabaseSessionEndEvent, - Code: events.DatabaseSessionEndCode, + Type: events.DatabaseSessionStartEvent, + Code: events.DatabaseSessionStartCode, }, SessionMetadata: apievents.SessionMetadata{ - SessionID: sessionID, + SessionID: sessionID.String(), }, - })) + }))) + require.NoError(t, stream.Complete(ctx)) accessedFormat := teleport.PTY - clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), session.ID(sessionID), 0) + clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), sessionID, 0) // Perform the listing an eventually loop to ensure the event is emitted. var searchEvents []apievents.AuditEvent diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 3570171f40996..274c3c65c56a6 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -509,9 +509,23 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID e := make(chan error, 1) c := make(chan apievents.AuditEvent) + sessionStartCh := make(chan apievents.AuditEvent, 1) + if startCb, err := sessionStartCallbackFromContext(ctx); err == nil { + go func() { + evt, ok := <-sessionStartCh + if !ok { + startCb(nil, trace.NotFound("session start event not found")) + return + } + + startCb(evt, nil) + }() + } + rawSession, err := os.CreateTemp(l.playbackDir, string(sessionID)+".stream.tar.*") if err != nil { e <- trace.Wrap(trace.ConvertSystemError(err), "creating temporary stream file") + close(sessionStartCh) return c, e } // The file is still perfectly usable after unlinking it, and the space it's @@ -528,6 +542,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID if err := os.Remove(rawSession.Name()); err != nil { _ = rawSession.Close() e <- trace.Wrap(trace.ConvertSystemError(err), "removing temporary stream file") + close(sessionStartCh) return c, e } @@ -538,6 +553,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID err = trace.NotFound("a recording for session %v was not found", sessionID) } e <- trace.Wrap(err) + close(sessionStartCh) return c, e } l.log.DebugContext(ctx, "Downloaded session to a temporary file for streaming.", @@ -547,6 +563,8 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID go func() { defer rawSession.Close() + defer close(sessionStartCh) + // this shouldn't be necessary as the position should be already 0 (Download // takes an io.WriterAt), but it's better to be safe than sorry if _, err := rawSession.Seek(0, io.SeekStart); err != nil { @@ -557,6 +575,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID protoReader := NewProtoReader(rawSession) defer protoReader.Close() + firstEvent := true for { if ctx.Err() != nil { e <- trace.Wrap(ctx.Err()) @@ -573,6 +592,11 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID return } + if firstEvent { + sessionStartCh <- event + firstEvent = false + } + if event.GetIndex() >= startIndex { select { case c <- event: @@ -667,3 +691,39 @@ func (l *AuditLog) periodicSpaceMonitor() { } } } + +// streamSessionEventsContextKey represent context keys used by +// StreamSessionEvents function. +type streamSessionEventsContextKey string + +const ( + // sessionStartCallbackContextKey is the context key used to store the + // session start callback function. + sessionStartCallbackContextKey streamSessionEventsContextKey = "session-start" +) + +// SessionStartCallback is the function used when streaming reaches the start +// event. If any error, such as session not found, the event will be nil, and +// the error will be set. +type SessionStartCallback func(startEvent apievents.AuditEvent, err error) + +// ContextWithSessionStartCallback returns a context.Context containing a +// session start event callback. +func ContextWithSessionStartCallback(ctx context.Context, cb SessionStartCallback) context.Context { + return context.WithValue(ctx, sessionStartCallbackContextKey, cb) +} + +// sessionStartCallbackFromContext returns the session start callback from +// context.Context. +func sessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, error) { + if ctx == nil { + return nil, trace.BadParameter("context is nil") + } + + cb, ok := ctx.Value(sessionStartCallbackContextKey).(SessionStartCallback) + if !ok { + return nil, trace.BadParameter("session start callback function was not found in the context") + } + + return cb, nil +} diff --git a/lib/events/auditlog_test.go b/lib/events/auditlog_test.go index b76d27a0ee36a..416373e3e6951 100644 --- a/lib/events/auditlog_test.go +++ b/lib/events/auditlog_test.go @@ -154,6 +154,137 @@ func TestConcurrentStreaming(t *testing.T) { } } +func TestStreamSessionEvents(t *testing.T) { + uploader := eventstest.NewMemoryUploader() + alog, err := events.NewAuditLog(events.AuditLogConfig{ + DataDir: t.TempDir(), + Clock: clockwork.NewFakeClock(), + ServerID: "remote", + UploadHandler: uploader, + }) + require.NoError(t, err) + t.Cleanup(func() { alog.Close() }) + + ctx := context.Background() + sid := session.NewID() + sessionEvents := []apievents.AuditEvent{ + &apievents.DatabaseSessionStart{ + Metadata: apievents.Metadata{ + Type: events.DatabaseSessionStartEvent, + Code: events.DatabaseSessionStartCode, + Index: 0, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: sid.String(), + }, + }, + &apievents.DatabaseSessionEnd{ + Metadata: apievents.Metadata{ + Type: events.DatabaseSessionEndEvent, + Code: events.DatabaseSessionEndCode, + Index: 1, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: sid.String(), + }, + }, + } + + streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ + Uploader: uploader, + }) + require.NoError(t, err) + stream, err := streamer.CreateAuditStream(ctx, sid) + require.NoError(t, err) + for _, event := range sessionEvents { + require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(event))) + } + require.NoError(t, stream.Complete(ctx)) + + type callbackResult struct { + event apievents.AuditEvent + err error + } + + t.Run("Success", func(t *testing.T) { + for name, withCallback := range map[string]bool{ + "WithCallback": true, + "WithoutCallback": false, + } { + t.Run(name, func(t *testing.T) { + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + callbackCh := make(chan callbackResult, 1) + if withCallback { + streamCtx = events.ContextWithSessionStartCallback(streamCtx, func(ae apievents.AuditEvent, err error) { + callbackCh <- callbackResult{ae, err} + }) + } + + ch, _ := alog.StreamSessionEvents(streamCtx, sid, 0) + for _, event := range sessionEvents { + select { + case receivedEvent := <-ch: + require.NotNil(t, receivedEvent) + require.Equal(t, event.GetCode(), receivedEvent.GetCode()) + require.Equal(t, event.GetType(), receivedEvent.GetType()) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to receive session event %q but got nothing", event.GetType()) + } + } + + if withCallback { + select { + case res := <-callbackCh: + require.NoError(t, res.err) + require.Equal(t, sessionEvents[0].GetCode(), res.event.GetCode()) + require.Equal(t, sessionEvents[0].GetType(), res.event.GetType()) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to receive callback result but got nothing") + } + } + }) + } + }) + + t.Run("Error", func(t *testing.T) { + for name, withCallback := range map[string]bool{ + "WithCallback": true, + "WithoutCallback": false, + } { + t.Run(name, func(t *testing.T) { + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + callbackCh := make(chan callbackResult, 1) + if withCallback { + streamCtx = events.ContextWithSessionStartCallback(streamCtx, func(ae apievents.AuditEvent, err error) { + callbackCh <- callbackResult{ae, err} + }) + } + + _, errCh := alog.StreamSessionEvents(streamCtx, session.ID("random"), 0) + select { + case err := <-errCh: + require.Error(t, err) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to get error while stream but got nothing") + } + + if withCallback { + select { + case res := <-callbackCh: + require.Error(t, res.err) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to receive callback result but got nothing") + } + } + }) + } + }) +} + func TestExternalLog(t *testing.T) { m := &eventstest.MockAuditLog{ Emitter: &eventstest.MockRecorderEmitter{},