Skip to content

Commit

Permalink
Disable session recording for non-interactive sessions (#41991) (#42321)
Browse files Browse the repository at this point in the history
These recordings only contain session.start, session.end and
session.leave events, all of which are already included in the
audit log. Removing these recordings should produce no data loss
but will greatly reduce the amount of work performed by the agents,
the auth service, and storage costs.

The only case where non-interactive sessions are still recording is
when BPF is enabled. This is required, for now, because enhanced
session recording can generate more events than the audit log has
traditionally been able to ingest.
  • Loading branch information
rosstimothy authored Jun 4, 2024
1 parent df202f1 commit 2aa9571
Show file tree
Hide file tree
Showing 12 changed files with 293 additions and 81 deletions.
3 changes: 0 additions & 3 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ const (

// SSHSessionID is the UUID of the current session.
SSHSessionID = "SSH_SESSION_ID"

// EnableNonInteractiveSessionRecording can be used to record non-interactive SSH session.
EnableNonInteractiveSessionRecording = "SSH_TELEPORT_RECORD_NON_INTERACTIVE"
)

const (
Expand Down
20 changes: 0 additions & 20 deletions integration/assist/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (
"github.com/gorilla/websocket"
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"google.golang.org/protobuf/types/known/timestamppb"
Expand Down Expand Up @@ -87,7 +86,6 @@ func TestAssistCommandOpenSSH(t *testing.T) {
openAIMock := mockOpenAI(t)

rc := setupTeleport(t, testDir, openAIMock.URL)
auth := rc.Process.GetAuthServer()
proxyAddr, err := rc.Process.ProxyWebAddr()
require.NoError(t, err)

Expand Down Expand Up @@ -158,24 +156,6 @@ func TestAssistCommandOpenSSH(t *testing.T) {
require.NoError(t, err)
require.Equal(t, defaults.WebsocketClose, envelope.Type)
// Now the execution is finished

// Waiting for the session recording to be uploaded and available
require.Eventually(t, func() bool {
chunk, err := auth.GetSessionChunk(apidefaults.Namespace, sessionMetadata.Session.ID, 0, 4096)
if err != nil {
if trace.IsNotFound(err) {
return false
}
assert.Fail(t, "error should be nil or not found, is %s", err)
}
assert.NotNil(t, chunk)
return true
}, 10*time.Second, 200*time.Millisecond)

// Validating the session recording contains the SSH server output
chunk, err := auth.GetSessionChunk(apidefaults.Namespace, sessionMetadata.Session.ID, 0, 4096)
require.NoError(t, err)
require.Equal(t, testCommandOutput, string(chunk))
}

// mockOpenAI starts an OpenAI mock server that answers one completion request
Expand Down
4 changes: 4 additions & 0 deletions lib/bpf/bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ func (s *Service) CloseSession(ctx *SessionContext) error {
return trace.NewAggregate(errs...)
}

func (s *Service) Enabled() bool {
return true
}

// processAccessEvents pulls events off the perf ring buffer, parses them, and emits them to
// the audit log.
func (s *Service) processAccessEvents() {
Expand Down
7 changes: 7 additions & 0 deletions lib/bpf/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ type BPF interface {

// Close will stop any running BPF programs.
Close(restarting bool) error

// Enabled returns whether enhanced recording is active.
Enabled() bool
}

// SessionContext contains all the information needed to track and emit
Expand Down Expand Up @@ -99,6 +102,10 @@ func (s *NOP) CloseSession(_ *SessionContext) error {
return nil
}

func (s *NOP) Enabled() bool {
return false
}

// IsHostCompatible checks that BPF programs can run on this host.
func IsHostCompatible() error {
minKernel := semver.New(constants.EnhancedRecordingMinKernel)
Expand Down
16 changes: 14 additions & 2 deletions lib/events/eventstest/mock_recorder_emitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import (

// MockRecorderEmitter is a recorder and emitter that stores all events.
type MockRecorderEmitter struct {
mu sync.RWMutex
events []apievents.AuditEvent
mu sync.RWMutex
events []apievents.AuditEvent
recordedEvents []apievents.PreparedSessionEvent
}

func (e *MockRecorderEmitter) Write(_ []byte) (int, error) {
Expand All @@ -46,6 +47,7 @@ func (e *MockRecorderEmitter) EmitAuditEvent(ctx context.Context, event apievent
func (e *MockRecorderEmitter) RecordEvent(ctx context.Context, event apievents.PreparedSessionEvent) error {
e.mu.Lock()
defer e.mu.Unlock()
e.recordedEvents = append(e.recordedEvents, event)
e.events = append(e.events, event.GetAuditEvent())
return nil
}
Expand All @@ -70,6 +72,16 @@ func (e *MockRecorderEmitter) Events() []apievents.AuditEvent {
return result
}

// RecordedEvents returns all the emitted events.
func (e *MockRecorderEmitter) RecordedEvents() []apievents.PreparedSessionEvent {
e.mu.RLock()
defer e.mu.RUnlock()

result := make([]apievents.PreparedSessionEvent, len(e.recordedEvents))
copy(result, e.recordedEvents)
return result
}

// Reset clears the emitted events.
func (e *MockRecorderEmitter) Reset() {
e.mu.Lock()
Expand Down
6 changes: 0 additions & 6 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,6 @@ type ServerContext struct {
killShellr *os.File
killShellw *os.File

// multiWriter is used to record non-interactive session output.
// Currently, used by Assist.
multiWriter io.Writer
// recordNonInteractiveSession enables non-interactive session recording. Used by Assist.
recordNonInteractiveSession bool

// ChannelType holds the type of the channel. For example "session" or
// "direct-tcpip". Used to create correct subcommand during re-exec.
ChannelType string
Expand Down
16 changes: 2 additions & 14 deletions lib/srv/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,7 @@ func (e *localExec) Start(ctx context.Context, channel ssh.Channel) (*ExecResult

// Connect stdout and stderr to the channel so the user can interact with the command.
e.Cmd.Stderr = channel.Stderr()

if e.Ctx.recordNonInteractiveSession {
e.Ctx.Tracef("Starting local exec and recording non-interactive session")
e.Cmd.Stdout = io.MultiWriter(e.Ctx.multiWriter, channel)
} else {
e.Cmd.Stdout = channel
}
e.Cmd.Stdout = channel

// Copy from the channel (client input) into stdin of the process.
inputWriter, err := e.Cmd.StdinPipe()
Expand Down Expand Up @@ -378,13 +372,7 @@ func (e *remoteExec) Start(ctx context.Context, ch ssh.Channel) (*ExecResult, er
}

// hook up stdout/err the channel so the user can interact with the command
if e.ctx.recordNonInteractiveSession {
e.ctx.Tracef("Starting remote exec and recording non-interactive session")
e.session.Stdout = io.MultiWriter(e.ctx.multiWriter, ch)
} else {
e.session.Stdout = ch
}

e.session.Stdout = ch
e.session.Stderr = ch.Stderr()
inputWriter, err := e.session.StdinPipe()
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1553,8 +1553,7 @@ func (s *Server) handlePuTTYWinadj(ch ssh.Channel, req *ssh.Request) error {
// teleportVarPrefixes contains the list of prefixes used by Teleport environment
// variables. Matching variables are saved in the session context when forwarding
// the calls to a remote SSH server as they can contain Teleport-specific
// information used to process the session properly (e.g. TELEPORT_SESSION or
// SSH_TELEPORT_RECORD_NON_INTERACTIVE)
// information used to process the session properly (e.g. TELEPORT_SESSION)
var teleportVarPrefixes = []string{"TELEPORT_", "SSH_TELEPORT_"}

func isTeleportEnv(varName string) bool {
Expand Down
25 changes: 25 additions & 0 deletions lib/srv/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type mockServer struct {
auth *auth.Server
component string
clock clockwork.FakeClock
bpf bpf.BPF
}

// ID is the unique ID of the server.
Expand Down Expand Up @@ -250,6 +251,10 @@ func (m *mockServer) UseTunnel() bool {

// GetBPF returns the BPF service used for enhanced session recording.
func (m *mockServer) GetBPF() bpf.BPF {
if m.bpf != nil {
return m.bpf
}

return &bpf.NOP{}
}

Expand Down Expand Up @@ -391,3 +396,23 @@ func (c *mockSSHChannel) SendRequest(name string, wantReply bool, payload []byte
func (c *mockSSHChannel) Stderr() io.ReadWriter {
return c.stdErr
}

type fakeBPF struct {
bpf bpf.NOP
}

func (f fakeBPF) OpenSession(ctx *bpf.SessionContext) (uint64, error) {
return f.bpf.OpenSession(ctx)
}

func (f fakeBPF) CloseSession(ctx *bpf.SessionContext) error {
return f.bpf.CloseSession(ctx)
}

func (f fakeBPF) Close(restarting bool) error {
return f.bpf.Close(restarting)
}

func (f fakeBPF) Enabled() bool {
return true
}
41 changes: 22 additions & 19 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,6 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann
scx.Tracef("Session found, reusing it %s", sessionID)
}

_, found = scx.GetEnv(teleport.EnableNonInteractiveSessionRecording)
if found {
scx.recordNonInteractiveSession = true
}

// This logic allows concurrent request to create a new session
// to fail, what is ok because we should never have this condition.
sess, _, err := newSession(ctx, sessionID, s, scx, channel)
Expand Down Expand Up @@ -1012,8 +1007,18 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {
sessionJoinEvent.ConnectionMetadata.LocalAddr = ctx.ServerConn.LocalAddr().String()
}

var notifyPartyPayload []byte
preparedEvent, err := s.Recorder().PrepareSessionEvent(sessionJoinEvent)
if err == nil {
// Try marshaling the event prior to emitting it to prevent races since
// the audit/recording machinery might try to set some fields while the
// marshal is underway.
if eventPayload, err := json.Marshal(preparedEvent); err != nil {
s.log.Warnf("Unable to marshal %v: %v.", events.SessionJoinEvent, err)
} else {
notifyPartyPayload = eventPayload
}

if err := s.recordEvent(ctx.srv.Context(), preparedEvent); err != nil {
s.log.WithError(err).Warn("Failed to record session join event.")
}
Expand All @@ -1027,12 +1032,15 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {
// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
for _, p := range s.parties {
eventPayload, err := json.Marshal(sessionJoinEvent)
if err != nil {
s.log.Warnf("Unable to marshal %v for %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
if len(notifyPartyPayload) == 0 {
s.log.Warnf("No join event to send to %v", p.sconn.RemoteAddr())
continue
}
_, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, eventPayload)

payload := make([]byte, len(notifyPartyPayload))
copy(payload, notifyPartyPayload)

_, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, payload)
if err != nil {
s.log.Warnf("Unable to send %v to %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
continue
Expand Down Expand Up @@ -1383,6 +1391,11 @@ func newRecorder(s *session, ctx *ServerContext) (events.SessionPreparerRecorder
return events.WithNoOpPreparer(events.NewDiscardRecorder()), nil
}

// Don't record non-interactive sessions when enhanced recording is disabled.
if ctx.GetTerm() == nil && !ctx.srv.GetBPF().Enabled() {
return events.WithNoOpPreparer(events.NewDiscardRecorder()), nil
}

rec, err := recorder.New(recorder.Config{
SessionID: s.id,
ServerID: s.serverMeta.ServerID,
Expand Down Expand Up @@ -1413,12 +1426,6 @@ func newRecorder(s *session, ctx *ServerContext) (events.SessionPreparerRecorder
}

func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *ServerContext) error {
if scx.recordNonInteractiveSession {
// enable recording.
s.io.AddWriter(sessionRecorderID, utils.WriteCloserWithContext(scx.srv.Context(), s.Recorder()))
s.scx.multiWriter = s.io
}

// Emit a session.start event for the exec session.
s.emitSessionStartEvent(scx)

Expand Down Expand Up @@ -1474,10 +1481,6 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
// Process has been placed in a cgroup, continue execution.
execRequest.Continue()

if scx.recordNonInteractiveSession {
s.io.On()
}

// Process is running, wait for it to stop.
go func() {
result = execRequest.Wait()
Expand Down
Loading

0 comments on commit 2aa9571

Please sign in to comment.