Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GitHub Proxy: complete audit event flow and add an enterprise guard on github integration #51049

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/auth/integration/integrationv1/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ import (
integrationpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/modules"
)

func TestExportIntegrationCertAuthorities(t *testing.T) {
t.Parallel()
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})

ca := newCertAuthority(t, types.HostCA, "test-cluster")
ctx, localClient, resourceSvc := initSvc(t, ca, ca.GetClusterName(), "127.0.0.1")
Expand Down
3 changes: 2 additions & 1 deletion lib/auth/integration/integrationv1/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/modules"
)

func TestGenerateGitHubUserCert(t *testing.T) {
t.Parallel()
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})

ca := newCertAuthority(t, types.HostCA, "test-cluster")
ctx, _, resourceSvc := initSvc(t, ca, ca.GetClusterName(), "127.0.0.1.nip.io")
Expand Down
5 changes: 4 additions & 1 deletion lib/auth/integration/integrationv1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
)

Expand Down Expand Up @@ -223,7 +224,9 @@ func (s *Service) CreateIntegration(ctx context.Context, req *integrationpb.Crea

switch req.Integration.GetSubKind() {
case types.IntegrationSubKindGitHub:
// TODO(greedy52) add entitlement check
if modules.GetModules().BuildType() != modules.BuildEnterprise {
return nil, trace.AccessDenied("GitHub integration requires a Teleport Enterprise license")
}
if err := s.createGitHubCredentials(ctx, req.Integration); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
4 changes: 3 additions & 1 deletion lib/auth/integration/integrationv1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ import (
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/tlsca"
)

func TestIntegrationCRUD(t *testing.T) {
t.Parallel()
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})

clusterName := "test-cluster"
proxyPublicAddr := "127.0.0.1.nip.io"

Expand Down
2 changes: 2 additions & 0 deletions lib/events/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ func FromEventFields(fields EventFields) (events.AuditEvent, error) {
e = &events.AccessGraphSettingsUpdate{}
case DatabaseSessionSpannerRPCEvent:
e = &events.SpannerRPC{}
case GitCommandEvent:
e = &events.GitCommand{}
case UnknownEvent:
e = &events.Unknown{}

Expand Down
4 changes: 3 additions & 1 deletion lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,9 @@ func (c *ServerContext) reportStats(conn utils.Stater) {
// Never emit session data events for the proxy or from a Teleport node if
// sessions are being recorded at the proxy (this would result in double
// events).
if c.GetServer().Component() == teleport.ComponentProxy {
// Do not emit session data for git commands as they have their own events.
if c.GetServer().Component() == teleport.ComponentProxy ||
c.GetServer().Component() == teleport.ComponentForwardingGit {
return
}
if services.IsRecordAtProxy(c.SessionRecordingConfig.GetMode()) &&
Expand Down
90 changes: 84 additions & 6 deletions lib/srv/git/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"io"
"log/slog"
"net"
"strconv"

"github.com/google/uuid"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -323,12 +324,19 @@ func (s *ForwardServer) onConnection(ctx context.Context, ccx *sshutils.Connecti

s.logger.Log(ctx, logutils.TraceLevel, "New connection accepted")
ccx.AddCloser(serverCtx)
return ctx, nil
return context.WithValue(ctx, serverContextKey{}, serverCtx), nil
}

func (s *ForwardServer) onChannel(ctx context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) {
s.logger.DebugContext(ctx, "Handling channel request", "channel", nch.ChannelType())

serverCtx, ok := ctx.Value(serverContextKey{}).(*srv.ServerContext)
if !ok {
// This should not happen. Double check just in case.
s.reply.RejectChannel(ctx, nch, ssh.ResourceShortage, "server context not found")
return
}

// Only expecting a session to execute a command.
if nch.ChannelType() != teleport.ChanSession {
s.reply.RejectUnknownChannel(ctx, nch)
Expand All @@ -353,7 +361,7 @@ func (s *ForwardServer) onChannel(ctx context.Context, ccx *sshutils.ConnectionC
}
defer ch.Close()

sctx := newSessionContext(ch, remoteSession)
sctx := newSessionContext(serverCtx, ch, remoteSession)
for {
select {
case req := <-in:
Expand Down Expand Up @@ -382,13 +390,16 @@ func (s *ForwardServer) onChannel(ctx context.Context, ccx *sshutils.ConnectionC
}

type sessionContext struct {
*srv.ServerContext

channel ssh.Channel
remoteSession *tracessh.Session
waitExec chan error
}

func newSessionContext(ch ssh.Channel, remoteSession *tracessh.Session) *sessionContext {
func newSessionContext(serverCtx *srv.ServerContext, ch ssh.Channel, remoteSession *tracessh.Session) *sessionContext {
return &sessionContext{
ServerContext: serverCtx,
channel: ch,
remoteSession: remoteSession,
waitExec: make(chan error, 1),
Expand All @@ -415,13 +426,27 @@ func (s *ForwardServer) dispatch(ctx context.Context, sctx *sessionContext, req
}

// handleExec proxies the Git command between client and the target server.
func (s *ForwardServer) handleExec(ctx context.Context, sctx *sessionContext, req *ssh.Request) error {
func (s *ForwardServer) handleExec(ctx context.Context, sctx *sessionContext, req *ssh.Request) (err error) {
var r sshutils.ExecReq
defer func() {
if err != nil {
s.emitEvent(s.makeGitCommandEvent(sctx, r.Command, err))
}
}()

if err := ssh.Unmarshal(req.Payload, &r); err != nil {
return trace.Wrap(err, "failed to unmarshal exec request")
}

// TODO(greedy52) enable command recorder for audit log
command, err := ParseSSHCommand(r.Command)
if err != nil {
return trace.Wrap(err)
}
if err := checkSSHCommand(s.cfg.TargetServer, command); err != nil {
return trace.Wrap(err)
}
recorder := NewCommandRecorder(ctx, *command)

sctx.remoteSession.Stdout = sctx.channel
sctx.remoteSession.Stderr = sctx.channel.Stderr()
remoteStdin, err := sctx.remoteSession.StdinPipe()
Expand All @@ -430,7 +455,7 @@ func (s *ForwardServer) handleExec(ctx context.Context, sctx *sessionContext, re
}
go func() {
defer remoteStdin.Close()
if _, err := io.Copy(remoteStdin, sctx.channel); err != nil {
if _, err := io.Copy(io.MultiWriter(remoteStdin, recorder), sctx.channel); err != nil {
s.logger.WarnContext(ctx, "Failed to copy git command stdin", "error", err)
}
}()
Expand All @@ -442,10 +467,61 @@ func (s *ForwardServer) handleExec(ctx context.Context, sctx *sessionContext, re
go func() {
execErr := sctx.remoteSession.Wait()
sctx.waitExec <- execErr
s.emitEvent(s.makeGitCommandEventWithExecResult(sctx, recorder, execErr))
}()
return nil
}

func (s *ForwardServer) emitEvent(event apievents.AuditEvent) {
if err := s.cfg.Emitter.EmitAuditEvent(s.cfg.ParentContext, event); err != nil {
s.logger.WarnContext(s.cfg.ParentContext, "Failed to emit event",
"error", err,
"event_type", event.GetType(),
"event_code", event.GetCode(),
)
}
}

func (s *ForwardServer) makeGitCommandEvent(sctx *sessionContext, command string, err error) *apievents.GitCommand {
event := &apievents.GitCommand{
Metadata: apievents.Metadata{
Type: events.GitCommandEvent,
Code: events.GitCommandCode,
},
UserMetadata: sctx.Identity.GetUserMetadata(),
SessionMetadata: sctx.GetSessionMetadata(),
CommandMetadata: apievents.CommandMetadata{
Command: command,
},
ConnectionMetadata: apievents.ConnectionMetadata{
RemoteAddr: sctx.ServerConn.RemoteAddr().String(),
LocalAddr: sctx.ServerConn.LocalAddr().String(),
},
ServerMetadata: s.TargetMetadata(),
}
if err != nil {
event.Metadata.Code = events.GitCommandFailureCode
event.Error = err.Error()
}
return event
}

func (s *ForwardServer) makeGitCommandEventWithExecResult(sctx *sessionContext, recorder CommandRecorder, execErr error) *apievents.GitCommand {
event := s.makeGitCommandEvent(sctx, recorder.GetCommand().SSHCommand, execErr)

event.ExitCode = strconv.Itoa(sshutils.ExitCodeFromExecError(execErr))
event.Path = string(recorder.GetCommand().Repository)
event.Service = recorder.GetCommand().Service

actions, err := recorder.GetActions()
if err != nil {
s.logger.WarnContext(s.cfg.ParentContext, "Failed to get actions from Git command recorder. No actions will be recorded in the event.", "error", err)
} else {
event.Actions = actions
}
return event
}

// handleEnv sets env on the target server.
func (s *ForwardServer) handleEnv(ctx context.Context, sctx *sessionContext, req *ssh.Request) error {
var e sshutils.EnvReqParams
Expand Down Expand Up @@ -594,6 +670,8 @@ func (s *ForwardServer) GetHostSudoers() srv.HostSudoers {
return nil
}

type serverContextKey struct{}

const (
gitUser = "git"
)
74 changes: 73 additions & 1 deletion lib/srv/git/forward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/constants"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/cryptosuites"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
Expand All @@ -66,6 +69,7 @@ func TestForwardServer(t *testing.T) {
verifyRemoteHost ssh.HostKeyCallback
wantNewClientError bool
verifyWithClient func(t *testing.T, ctx context.Context, client *tracessh.Client, m *mockGitHostingService)
verifyEvent func(t *testing.T, event apievents.AuditEvent)
}{
{
name: "success",
Expand All @@ -85,6 +89,45 @@ func TestForwardServer(t *testing.T) {
require.NoError(t, err)
require.Equal(t, gitCommand, m.receivedExec.Command)
},
verifyEvent: func(t *testing.T, event apievents.AuditEvent) {
gitEvent, ok := event.(*apievents.GitCommand)
require.True(t, ok)
assert.Equal(t, libevents.GitCommandEvent, gitEvent.Metadata.Type)
assert.Equal(t, libevents.GitCommandCode, gitEvent.Metadata.Code)
assert.Equal(t, "alice", gitEvent.User)
assert.Equal(t, "0", gitEvent.CommandMetadata.ExitCode)
assert.Equal(t, "git-upload-pack", gitEvent.Service)
assert.Equal(t, "org/my-repo.git", gitEvent.Path)
},
},
{
name: "command failure",
allowedGitHubOrg: "*",
clientLogin: "git",
verifyRemoteHost: ssh.InsecureIgnoreHostKey(),
wantNewClientError: false,
verifyWithClient: func(t *testing.T, ctx context.Context, client *tracessh.Client, m *mockGitHostingService) {
m.exitCode = 1

session, err := client.NewSession(ctx)
require.NoError(t, err)
defer session.Close()

gitCommand := "git-receive-pack 'org/my-repo.git'"
session.Stderr = io.Discard
session.Stdout = io.Discard
require.Error(t, session.Run(ctx, gitCommand))
},
verifyEvent: func(t *testing.T, event apievents.AuditEvent) {
gitEvent, ok := event.(*apievents.GitCommand)
require.True(t, ok)
assert.Equal(t, libevents.GitCommandEvent, gitEvent.Metadata.Type)
assert.Equal(t, libevents.GitCommandFailureCode, gitEvent.Metadata.Code)
assert.Equal(t, "alice", gitEvent.User)
assert.Equal(t, "1", gitEvent.CommandMetadata.ExitCode)
assert.Equal(t, "git-receive-pack", gitEvent.Service)
assert.Equal(t, "org/my-repo.git", gitEvent.Path)
},
},
{
name: "failed RBAC",
Expand Down Expand Up @@ -124,6 +167,31 @@ func TestForwardServer(t *testing.T) {
require.Contains(t, err.Error(), "unknown channel type")
},
},
{
name: "org mismatch",
allowedGitHubOrg: "*",
clientLogin: "git",
verifyRemoteHost: ssh.InsecureIgnoreHostKey(),
wantNewClientError: false,
verifyWithClient: func(t *testing.T, ctx context.Context, client *tracessh.Client, m *mockGitHostingService) {
session, err := client.NewSession(ctx)
require.NoError(t, err)
defer session.Close()

gitCommand := "git-upload-pack 'some-other-org/my-repo.git'"
session.Stderr = io.Discard
session.Stdout = io.Discard
require.Error(t, session.Run(ctx, gitCommand))
},
verifyEvent: func(t *testing.T, event apievents.AuditEvent) {
gitEvent, ok := event.(*apievents.GitCommand)
require.True(t, ok)
assert.Equal(t, libevents.GitCommandEvent, gitEvent.Metadata.Type)
assert.Equal(t, libevents.GitCommandFailureCode, gitEvent.Metadata.Code)
assert.Equal(t, "alice", gitEvent.User)
assert.Contains(t, gitEvent.Error, "expect organization")
},
},
}

for _, test := range tests {
Expand Down Expand Up @@ -189,6 +257,9 @@ func TestForwardServer(t *testing.T) {
defer client.Close()

test.verifyWithClient(t, ctx, client, mockGitService)
if test.verifyEvent != nil {
test.verifyEvent(t, mockEmitter.LastEvent())
}
})
}

Expand Down Expand Up @@ -248,6 +319,7 @@ type mockGitHostingService struct {
*sshutils.Server
*sshutils.Reply
receivedExec sshutils.ExecReq
exitCode int
}

func newMockGitHostingService(t *testing.T, caSigner ssh.Signer) *mockGitHostingService {
Expand Down Expand Up @@ -300,7 +372,7 @@ func (m *mockGitHostingService) HandleNewChan(ctx context.Context, ccx *sshutils
m.ReplyRequest(ctx, req, true, nil)
}
slog.DebugContext(ctx, "mock git service receives new exec request", "req", m.receivedExec)
m.SendExitStatus(ctx, ch, 0)
m.SendExitStatus(ctx, ch, m.exitCode)
return

case <-ctx.Done():
Expand Down
Loading