diff --git a/constants.go b/constants.go index 34af133393f7b..1c9ce41f4350f 100644 --- a/constants.go +++ b/constants.go @@ -46,9 +46,6 @@ const ( // SSHSessionID is the UUID of the current session. SSHSessionID = "SSH_SESSION_ID" - - // EnableNonInteractiveSessionRecording can be used to record non-interactive SSH session. - EnableNonInteractiveSessionRecording = "SSH_TELEPORT_RECORD_NON_INTERACTIVE" ) const ( diff --git a/integration/assist/command_test.go b/integration/assist/command_test.go index ed855c6416165..9afd3afcfb724 100644 --- a/integration/assist/command_test.go +++ b/integration/assist/command_test.go @@ -40,7 +40,6 @@ import ( "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/protobuf/types/known/timestamppb" @@ -87,7 +86,6 @@ func TestAssistCommandOpenSSH(t *testing.T) { openAIMock := mockOpenAI(t) rc := setupTeleport(t, testDir, openAIMock.URL) - auth := rc.Process.GetAuthServer() proxyAddr, err := rc.Process.ProxyWebAddr() require.NoError(t, err) @@ -158,24 +156,6 @@ func TestAssistCommandOpenSSH(t *testing.T) { require.NoError(t, err) require.Equal(t, defaults.WebsocketClose, envelope.Type) // Now the execution is finished - - // Waiting for the session recording to be uploaded and available - require.Eventually(t, func() bool { - chunk, err := auth.GetSessionChunk(apidefaults.Namespace, sessionMetadata.Session.ID, 0, 4096) - if err != nil { - if trace.IsNotFound(err) { - return false - } - assert.Fail(t, "error should be nil or not found, is %s", err) - } - assert.NotNil(t, chunk) - return true - }, 10*time.Second, 200*time.Millisecond) - - // Validating the session recording contains the SSH server output - chunk, err := auth.GetSessionChunk(apidefaults.Namespace, sessionMetadata.Session.ID, 0, 4096) - require.NoError(t, err) - require.Equal(t, testCommandOutput, 
string(chunk)) } // mockOpenAI starts an OpenAI mock server that answers one completion request diff --git a/lib/bpf/bpf.go b/lib/bpf/bpf.go index 3deb36aebe8f9..1b25ee416785b 100644 --- a/lib/bpf/bpf.go +++ b/lib/bpf/bpf.go @@ -304,6 +304,10 @@ func (s *Service) CloseSession(ctx *SessionContext) error { return trace.NewAggregate(errs...) } +func (s *Service) Enabled() bool { + return true +} + // processAccessEvents pulls events off the perf ring buffer, parses them, and emits them to // the audit log. func (s *Service) processAccessEvents() { diff --git a/lib/bpf/common.go b/lib/bpf/common.go index cf9f7b967fadc..0aaf213c85559 100644 --- a/lib/bpf/common.go +++ b/lib/bpf/common.go @@ -40,6 +40,9 @@ type BPF interface { // Close will stop any running BPF programs. Close(restarting bool) error + + // Enabled returns whether enhanced recording is active. + Enabled() bool } // SessionContext contains all the information needed to track and emit @@ -99,6 +102,10 @@ func (s *NOP) CloseSession(_ *SessionContext) error { return nil } +func (s *NOP) Enabled() bool { + return false +} + // IsHostCompatible checks that BPF programs can run on this host. func IsHostCompatible() error { minKernel := semver.New(constants.EnhancedRecordingMinKernel) diff --git a/lib/events/eventstest/mock_recorder_emitter.go b/lib/events/eventstest/mock_recorder_emitter.go index 98f737a59e87a..7b724aa79e098 100644 --- a/lib/events/eventstest/mock_recorder_emitter.go +++ b/lib/events/eventstest/mock_recorder_emitter.go @@ -26,8 +26,9 @@ import ( // MockRecorderEmitter is a recorder and emitter that stores all events. 
type MockRecorderEmitter struct { - mu sync.RWMutex - events []apievents.AuditEvent + mu sync.RWMutex + events []apievents.AuditEvent + recordedEvents []apievents.PreparedSessionEvent } func (e *MockRecorderEmitter) Write(_ []byte) (int, error) { @@ -46,6 +47,7 @@ func (e *MockRecorderEmitter) EmitAuditEvent(ctx context.Context, event apievent func (e *MockRecorderEmitter) RecordEvent(ctx context.Context, event apievents.PreparedSessionEvent) error { e.mu.Lock() defer e.mu.Unlock() + e.recordedEvents = append(e.recordedEvents, event) e.events = append(e.events, event.GetAuditEvent()) return nil } @@ -70,6 +72,16 @@ func (e *MockRecorderEmitter) Events() []apievents.AuditEvent { return result } +// RecordedEvents returns a copy of all the recorded session events. +func (e *MockRecorderEmitter) RecordedEvents() []apievents.PreparedSessionEvent { + e.mu.RLock() + defer e.mu.RUnlock() + + result := make([]apievents.PreparedSessionEvent, len(e.recordedEvents)) + copy(result, e.recordedEvents) + return result +} + // Reset clears the emitted events. func (e *MockRecorderEmitter) Reset() { e.mu.Lock() diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 707b50172ca2f..68af896f5dafb 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -409,12 +409,6 @@ type ServerContext struct { killShellr *os.File killShellw *os.File - // multiWriter is used to record non-interactive session output. - // Currently, used by Assist. - multiWriter io.Writer - // recordNonInteractiveSession enables non-interactive session recording. Used by Assist. - recordNonInteractiveSession bool - // ChannelType holds the type of the channel. For example "session" or // "direct-tcpip". Used to create correct subcommand during re-exec. 
ChannelType string diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 0b7bb2d0ca0e1..b50fb90218e40 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -159,13 +159,7 @@ func (e *localExec) Start(ctx context.Context, channel ssh.Channel) (*ExecResult // Connect stdout and stderr to the channel so the user can interact with the command. e.Cmd.Stderr = channel.Stderr() - - if e.Ctx.recordNonInteractiveSession { - e.Ctx.Tracef("Starting local exec and recording non-interactive session") - e.Cmd.Stdout = io.MultiWriter(e.Ctx.multiWriter, channel) - } else { - e.Cmd.Stdout = channel - } + e.Cmd.Stdout = channel // Copy from the channel (client input) into stdin of the process. inputWriter, err := e.Cmd.StdinPipe() @@ -378,13 +372,7 @@ func (e *remoteExec) Start(ctx context.Context, ch ssh.Channel) (*ExecResult, er } // hook up stdout/err the channel so the user can interact with the command - if e.ctx.recordNonInteractiveSession { - e.ctx.Tracef("Starting remote exec and recording non-interactive session") - e.session.Stdout = io.MultiWriter(e.ctx.multiWriter, ch) - } else { - e.session.Stdout = ch - } - + e.session.Stdout = ch e.session.Stderr = ch.Stderr() inputWriter, err := e.session.StdinPipe() if err != nil { diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 379f826d5bb30..df84992ad45f8 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -1553,8 +1553,7 @@ func (s *Server) handlePuTTYWinadj(ch ssh.Channel, req *ssh.Request) error { // teleportVarPrefixes contains the list of prefixes used by Teleport environment // variables. Matching variables are saved in the session context when forwarding // the calls to a remote SSH server as they can contain Teleport-specific -// information used to process the session properly (e.g. TELEPORT_SESSION or -// SSH_TELEPORT_RECORD_NON_INTERACTIVE) +// information used to process the session properly (e.g. 
TELEPORT_SESSION) var teleportVarPrefixes = []string{"TELEPORT_", "SSH_TELEPORT_"} func isTeleportEnv(varName string) bool { diff --git a/lib/srv/mock.go b/lib/srv/mock.go index 26b8bbd87911b..c3887076ed5c1 100644 --- a/lib/srv/mock.go +++ b/lib/srv/mock.go @@ -156,6 +156,7 @@ type mockServer struct { auth *auth.Server component string clock clockwork.FakeClock + bpf bpf.BPF } // ID is the unique ID of the server. @@ -250,6 +251,10 @@ func (m *mockServer) UseTunnel() bool { // GetBPF returns the BPF service used for enhanced session recording. func (m *mockServer) GetBPF() bpf.BPF { + if m.bpf != nil { + return m.bpf + } + return &bpf.NOP{} } @@ -391,3 +396,23 @@ func (c *mockSSHChannel) SendRequest(name string, wantReply bool, payload []byte func (c *mockSSHChannel) Stderr() io.ReadWriter { return c.stdErr } + +type fakeBPF struct { + bpf bpf.NOP +} + +func (f fakeBPF) OpenSession(ctx *bpf.SessionContext) (uint64, error) { + return f.bpf.OpenSession(ctx) +} + +func (f fakeBPF) CloseSession(ctx *bpf.SessionContext) error { + return f.bpf.CloseSession(ctx) +} + +func (f fakeBPF) Close(restarting bool) error { + return f.bpf.Close(restarting) +} + +func (f fakeBPF) Enabled() bool { + return true +} diff --git a/lib/srv/sess.go b/lib/srv/sess.go index cbb4a6da1cbea..20fe70c20503f 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -365,11 +365,6 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann scx.Tracef("Session found, reusing it %s", sessionID) } - _, found = scx.GetEnv(teleport.EnableNonInteractiveSessionRecording) - if found { - scx.recordNonInteractiveSession = true - } - // This logic allows concurrent request to create a new session // to fail, what is ok because we should never have this condition. 
sess, _, err := newSession(ctx, sessionID, s, scx, channel) @@ -1012,8 +1007,18 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { sessionJoinEvent.ConnectionMetadata.LocalAddr = ctx.ServerConn.LocalAddr().String() } + var notifyPartyPayload []byte preparedEvent, err := s.Recorder().PrepareSessionEvent(sessionJoinEvent) if err == nil { + // Try marshaling the event prior to emitting it to prevent races since + // the audit/recording machinery might try to set some fields while the + // marshal is underway. + if eventPayload, err := json.Marshal(preparedEvent); err != nil { + s.log.Warnf("Unable to marshal %v: %v.", events.SessionJoinEvent, err) + } else { + notifyPartyPayload = eventPayload + } + if err := s.recordEvent(ctx.srv.Context(), preparedEvent); err != nil { s.log.WithError(err).Warn("Failed to record session join event.") } @@ -1027,12 +1032,15 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { // Notify all members of the party that a new member has joined over the // "x-teleport-event" channel. for _, p := range s.parties { - eventPayload, err := json.Marshal(sessionJoinEvent) - if err != nil { - s.log.Warnf("Unable to marshal %v for %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err) + if len(notifyPartyPayload) == 0 { + s.log.Warnf("No join event to send to %v", p.sconn.RemoteAddr()) continue } - _, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, eventPayload) + + payload := make([]byte, len(notifyPartyPayload)) + copy(payload, notifyPartyPayload) + + _, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, payload) if err != nil { s.log.Warnf("Unable to send %v to %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err) continue @@ -1383,6 +1391,11 @@ func newRecorder(s *session, ctx *ServerContext) (events.SessionPreparerRecorder return events.WithNoOpPreparer(events.NewDiscardRecorder()), nil } + // Don't record non-interactive sessions when enhanced recording is disabled. 
+ if ctx.GetTerm() == nil && !ctx.srv.GetBPF().Enabled() { + return events.WithNoOpPreparer(events.NewDiscardRecorder()), nil + } + rec, err := recorder.New(recorder.Config{ SessionID: s.id, ServerID: s.serverMeta.ServerID, @@ -1413,12 +1426,6 @@ func newRecorder(s *session, ctx *ServerContext) (events.SessionPreparerRecorder } func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *ServerContext) error { - if scx.recordNonInteractiveSession { - // enable recording. - s.io.AddWriter(sessionRecorderID, utils.WriteCloserWithContext(scx.srv.Context(), s.Recorder())) - s.scx.multiWriter = s.io - } - // Emit a session.start event for the exec session. s.emitSessionStartEvent(scx) @@ -1474,10 +1481,6 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve // Process has been placed in a cgroup, continue execution. execRequest.Continue() - if scx.recordNonInteractiveSession { - s.io.On() - } - // Process is running, wait for it to stop. go func() { result = execRequest.Wait() diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index a3017e0daf698..db4bd50c67659 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -20,6 +20,7 @@ import ( "context" "io" "os/user" + "slices" "sync/atomic" "testing" "time" @@ -264,6 +265,7 @@ func TestSession_newRecorder(t *testing.T) { }, sctx: &ServerContext{ SessionRecordingConfig: proxyRecording, + term: &terminal{}, }, errAssertion: require.NoError, recAssertion: isNotSessionWriter, @@ -283,6 +285,7 @@ func TestSession_newRecorder(t *testing.T) { }, sctx: &ServerContext{ SessionRecordingConfig: proxyRecordingSync, + term: &terminal{}, }, errAssertion: require.NoError, recAssertion: isNotSessionWriter, @@ -305,6 +308,7 @@ func TestSession_newRecorder(t *testing.T) { srv: &mockServer{ component: teleport.ComponentNode, }, + term: &terminal{}, Identity: IdentityContext{ AccessChecker: services.NewAccessCheckerWithRoleSet(&services.AccessInfo{ Roles: []string{"dev"}, @@ -361,6 
+365,7 @@ func TestSession_newRecorder(t *testing.T) { }, }), }, + term: &terminal{}, }, errAssertion: require.NoError, recAssertion: func(t require.TestingT, i interface{}, _ ...interface{}) { @@ -390,6 +395,7 @@ func TestSession_newRecorder(t *testing.T) { MockRecorderEmitter: &eventstest.MockRecorderEmitter{}, datadir: t.TempDir(), }, + term: &terminal{}, }, errAssertion: require.NoError, recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { @@ -458,12 +464,14 @@ func TestSession_emitAuditEvent(t *testing.T) { }) } -// TestInteractiveSession tests interaction session lifecycles. -// Multiple sessions are opened in parallel tests to test for -// deadlocks between session registry, sessions, and parties. +// TestInteractiveSession tests interactive session lifecycles +// and validates audit events and session recordings are emitted. func TestInteractiveSession(t *testing.T) { + t.Parallel() + srv := newMockServer(t) srv.component = teleport.ComponentNode + t.Cleanup(func() { require.NoError(t, srv.auth.Close()) }) reg, err := NewSessionRegistry(SessionRegistryConfig{ Srv: srv, @@ -472,19 +480,217 @@ func TestInteractiveSession(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { reg.Close() }) - t.Run("Stop", func(t *testing.T) { + // Create a server context with an overridden recording mode + // so that sessions are recorded with the test emitter. + scx := newTestServerContext(t, reg.Srv, nil) + rcfg := types.DefaultSessionRecordingConfig() + rcfg.SetMode(types.RecordAtNodeSync) + scx.SessionRecordingConfig = rcfg + + // Allocate a terminal for the session so that + // events are properly recorded. 
+ terminal, err := newLocalTerminal(scx) + require.NoError(t, err) + scx.term = terminal + + // Open a new session + sshChanOpen := newMockSSHChannel() + go func() { + // Consume stdout sent to the channel + io.ReadAll(sshChanOpen) + }() + require.NoError(t, reg.OpenSession(context.Background(), sshChanOpen, scx)) + require.NotNil(t, scx.session) + + // Simulate changing window size to capture an additional event. + require.NoError(t, reg.NotifyWinChange(context.Background(), rsession.TerminalParams{W: 100, H: 100}, scx)) + + // Stopping the session should trigger the session + // to end and cleanup in the background + scx.session.Stop() + + // Wait for the session to be removed from the registry. + require.Eventually(t, func() bool { + _, found := reg.findSession(scx.session.id) + return !found + }, time.Second*15, time.Millisecond*500) + + // Validate that the expected audit events were emitted. + expectedEvents := []string{events.SessionStartEvent, events.ResizeEvent, events.SessionEndEvent, events.SessionLeaveEvent} + require.Eventually(t, func() bool { + actual := srv.MockRecorderEmitter.Events() + + for _, evt := range expectedEvents { + contains := slices.ContainsFunc(actual, func(event apievents.AuditEvent) bool { + return event.GetType() == evt + }) + if !contains { + return false + } + } + return true + }, 15*time.Second, 500*time.Millisecond) + + // Validate that the expected recording events were emitted. + require.Eventually(t, func() bool { + actual := srv.MockRecorderEmitter.RecordedEvents() + + for _, evt := range expectedEvents { + contains := slices.ContainsFunc(actual, func(event apievents.PreparedSessionEvent) bool { + return event.GetAuditEvent().GetType() == evt + }) + if !contains { + return false + } + } + + return true + }, 15*time.Second, 500*time.Millisecond) +} + +// TestNonInteractiveSession tests non-interactive session lifecycles +// and validates audit events and session recordings are emitted when +// appropriate. 
+func TestNonInteractiveSession(t *testing.T) { + t.Parallel() + + t.Run("without BPF", func(t *testing.T) { + t.Parallel() + + srv := newMockServer(t) + srv.component = teleport.ComponentNode + t.Cleanup(func() { require.NoError(t, srv.auth.Close()) }) + + reg, err := NewSessionRegistry(SessionRegistryConfig{ + Srv: srv, + SessionTrackerService: srv.auth, + }) + require.NoError(t, err) + t.Cleanup(func() { reg.Close() }) + + // Create a server context with an overridden recording mode + // so that sessions are recorded with the test emitter. + scx := newTestServerContext(t, reg.Srv, nil) + rcfg := types.DefaultSessionRecordingConfig() + rcfg.SetMode(types.RecordAtNodeSync) + scx.SessionRecordingConfig = rcfg + + // Modify the execRequest to actually execute a command. + scx.execRequest = &localExec{Ctx: scx, Command: "true"} + + // Open a new session + sshChanOpen := newMockSSHChannel() + go func() { + // Consume stdout sent to the channel + io.ReadAll(sshChanOpen) + }() + require.NoError(t, reg.OpenExecSession(context.Background(), sshChanOpen, scx)) + require.NotNil(t, scx.session) + + // Wait for the command execution to complete and the session to be terminated. + require.Eventually(t, func() bool { + _, found := reg.findSession(scx.session.id) + return !found + }, time.Second*15, time.Millisecond*500) + + // Verify that all the expected audit events are eventually emitted. 
+ expected := []string{events.SessionStartEvent, events.ExecEvent, events.SessionEndEvent, events.SessionLeaveEvent} + require.Eventually(t, func() bool { + actual := srv.MockRecorderEmitter.Events() + + for _, evt := range expected { + contains := slices.ContainsFunc(actual, func(event apievents.AuditEvent) bool { + return event.GetType() == evt + }) + if !contains { + return false + } + } + + return true + }, 15*time.Second, 500*time.Millisecond) + + // Verify that NO recordings were emitted + require.Empty(t, srv.MockRecorderEmitter.RecordedEvents()) + }) + + t.Run("with BPF", func(t *testing.T) { t.Parallel() - sess, _ := testOpenSession(t, reg, nil) - // Stopping the session should trigger the session - // to end and cleanup in the background - sess.Stop() + srv := newMockServer(t) + srv.component = teleport.ComponentNode + // Modify bpf to "enable" enhanced recording. This should + // trigger recordings to be captured. + srv.bpf = fakeBPF{} + t.Cleanup(func() { require.NoError(t, srv.auth.Close()) }) - sessionClosed := func() bool { - _, found := reg.findSession(sess.id) + reg, err := NewSessionRegistry(SessionRegistryConfig{ + Srv: srv, + SessionTrackerService: srv.auth, + }) + require.NoError(t, err) + t.Cleanup(func() { reg.Close() }) + + // Create a server context with an overridden recording mode + // so that sessions are recorded with the test emitter. + scx := newTestServerContext(t, reg.Srv, nil) + rcfg := types.DefaultSessionRecordingConfig() + rcfg.SetMode(types.RecordAtNodeSync) + scx.SessionRecordingConfig = rcfg + + // Modify the execRequest to actually execute a command. 
+ scx.execRequest = &localExec{Ctx: scx, Command: "true"} + + // Open a new session + sshChanOpen := newMockSSHChannel() + go func() { + // Consume stdout sent to the channel + io.ReadAll(sshChanOpen) + }() + require.NoError(t, reg.OpenExecSession(context.Background(), sshChanOpen, scx)) + require.NotNil(t, scx.session) + + // Wait for the command execution to complete and the session to be terminated. + require.Eventually(t, func() bool { + _, found := reg.findSession(scx.session.id) return !found - } - require.Eventually(t, sessionClosed, time.Second*15, time.Millisecond*500) + }, time.Second*15, time.Millisecond*500) + + // Verify that all the expected audit events are eventually emitted. + expectedEvents := []string{events.SessionStartEvent, events.ExecEvent, events.SessionEndEvent, events.SessionLeaveEvent} + require.Eventually(t, func() bool { + actual := srv.MockRecorderEmitter.Events() + + for _, evt := range expectedEvents { + contains := slices.ContainsFunc(actual, func(event apievents.AuditEvent) bool { + return event.GetType() == evt + }) + if !contains { + return false + } + } + + return true + }, 15*time.Second, 500*time.Millisecond) + + // Validate that the expected recording events were emitted. 
+ require.Eventually(t, func() bool { + actual := srv.MockRecorderEmitter.RecordedEvents() + + for _, evt := range expectedEvents { + if evt == events.ExecEvent { + continue + } + contains := slices.ContainsFunc(actual, func(event apievents.PreparedSessionEvent) bool { + return event.GetAuditEvent().GetType() == evt + }) + if !contains { + return false + } + } + + return true + }, 15*time.Second, 500*time.Millisecond) }) } diff --git a/lib/web/command.go b/lib/web/command.go index 65e6892f815f5..cca1f5a39058f 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -665,9 +665,6 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl defer nc.Close() - // Enable session recording - nc.AddEnv(teleport.EnableNonInteractiveSessionRecording, "true") - // Establish SSH connection to the server. This function will block until // either an error occurs or it completes successfully. if err = nc.RunCommand(ctx, t.interactiveCommand); err != nil {