From 723f751774b498393006ab0c9cce1da01e1e1fb1 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:44:59 -0500 Subject: [PATCH] [v17] Sanitize SSH server hostnames (#49091) * Sanitize SSH server hostnames Prevents any invalid and malicious hostnames, but replacing them with known valid data already associated with the host. This was chosen instead of rejecting to persist the server resource in an attempt to continue providing access to the host in order to remedy the invalid hostname. Any servers that represent a Teleport ssh_service with an invalid hostname will be replaced by the host UUID. Any static OpenSSH servers will have invalid hostnames replaced with the address. This will continue to allow the hosts to be dialable. In order to make these hosts discoverable, the invalid hostname will be set in the "teleport.internal/invalid-hostname" label. Updates https://github.com/gravitational/teleport-private/issues/1676. * add and use internal update node method * add test coverage for UpdateNode --- lib/auth/auth.go | 83 +++++++++++++++++ lib/auth/auth_test.go | 138 ++++++++++++++++++++++++++-- lib/auth/grpcserver_test.go | 4 +- lib/services/local/presence.go | 28 ++++++ lib/services/local/presence_test.go | 25 +++++ lib/services/presence.go | 1 + lib/services/suite/suite.go | 3 + lib/web/apiserver_test.go | 5 +- 8 files changed, 273 insertions(+), 14 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index dcbcf632b598d..356011ec69c47 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -39,7 +39,9 @@ import ( "log/slog" "math/big" insecurerand "math/rand" + "net" "os" + "regexp" "slices" "sort" "strconv" @@ -1470,6 +1472,28 @@ func (a *Server) runPeriodicOperations() { if services.NodeHasMissedKeepAlives(srv) { missedKeepAliveCount++ } + + // TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then. + if !validServerHostname(srv.GetHostname()) { + if srv.GetSubKind() != types.SubKindOpenSSHNode { + return false, nil + } + + logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname()) + + logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname") + // Any existing static hosts will not have their + // hostname sanitized since they don't heartbeat. + if err := sanitizeHostname(srv); err != nil { + logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err) + return false, nil + } + + if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) { + logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err) + } + } + return false, nil }, req, @@ -5618,9 +5642,68 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error { return nil } +const ( + serverHostnameMaxLen = 256 + serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$` + replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname" +) + +var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern) + +// validServerHostname returns false if the hostname is longer than 256 characters or +// does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also +// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol. +func validServerHostname(hostname string) bool { + return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname) +} + +func sanitizeHostname(server types.Server) error { + invalidHostname := server.GetHostname() + + replacedHostname := server.GetName() + if server.GetSubKind() == types.SubKindOpenSSHNode { + host, _, err := net.SplitHostPort(server.GetAddr()) + if err != nil || !validServerHostname(host) { + id, err := uuid.NewRandom() + if err != nil { + return trace.Wrap(err) + } + + host = id.String() + } + + replacedHostname = host + } + + switch s := server.(type) { + case *types.ServerV2: + s.Spec.Hostname = replacedHostname + + if s.Metadata.Labels == nil { + s.Metadata.Labels = map[string]string{} + } + + s.Metadata.Labels[replacedHostnameLabel] = invalidHostname + default: + return trace.BadParameter("invalid server provided") + } + + return nil +} + // UpsertNode implements [services.Presence] by delegating to [Server.Services] // and potentially emitting a [usagereporter] event. func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) { + if !validServerHostname(server.GetHostname()) { + a.logger.DebugContext(a.closeCtx, "sanitizing invalid server hostname", + "server", server.GetName(), + "hostname", server.GetHostname(), + ) + if err := sanitizeHostname(server); err != nil { + return nil, trace.Wrap(err) + } + } + lease, err := a.Services.UpsertNode(ctx, server) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 53150ac3e176b..e4978e32e358a 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -19,6 +19,7 @@ package auth import ( + "cmp" "context" "crypto/rand" "crypto/x509" @@ -34,7 +35,7 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" + gocmp "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/gravitational/license" @@ -307,7 +308,7 @@ func TestSessions(t *testing.T) { require.NoError(t, err) assert.Empty(t, out.GetSSHPriv()) assert.Empty(t, out.GetTLSPriv()) - assert.Empty(t, cmp.Diff(ws, out, + assert.Empty(t, gocmp.Diff(ws, out, cmpopts.IgnoreFields(types.Metadata{}, "Revision"), cmpopts.IgnoreFields(types.WebSessionSpecV2{}, "Priv", "TLSPriv"))) @@ -1655,7 +1656,7 @@ func TestServer_AugmentContextUserCertificates(t *testing.T) { AssetTag: test.opts.DeviceExtensions.AssetTag, CredentialId: test.opts.DeviceExtensions.CredentialID, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := gocmp.Diff(want, got); diff != "" { t.Errorf("certEvent.Identity.DeviceExtensions mismatch (-want +got)\n%s", diff) } } @@ -2301,12 +2302,12 @@ func TestServer_ExtendWebSession_deviceExtensions(t *testing.T) { // Assert TLS extensions. _, newIdentity := parseX509PEMAndIdentity(t, newSession.GetTLSCert()) wantExts := tlsca.DeviceExtensions(*deviceExts) - if diff := cmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" { + if diff := gocmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" { t.Errorf("newSession.TLSCert DeviceExtensions mismatch (-want +got)\n%s", diff) } // Assert SSH extensions. - if diff := cmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" { + if diff := gocmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" { t.Errorf("newSession.Pub DeviceExtensions mismatch (-want +got)\n%s", diff) } }) @@ -2545,7 +2546,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) { // Validate audit event. lastEvent := p.mockEmitter.LastEvent() require.IsType(t, &apievents.CertificateCreate{}, lastEvent) - require.Empty(t, cmp.Diff( + require.Empty(t, gocmp.Diff( &apievents.CertificateCreate{ Metadata: apievents.Metadata{ Type: events.CertificateCreateEvent, @@ -3801,15 +3802,15 @@ func compareDevices(t *testing.T, ignoreUpdateAndCounter bool, got []*types.MFAD } // Ignore LastUsed and SignatureCounter? - var opts []cmp.Option + var opts []gocmp.Option if ignoreUpdateAndCounter { - opts = append(opts, cmp.FilterPath(func(path cmp.Path) bool { + opts = append(opts, gocmp.FilterPath(func(path gocmp.Path) bool { p := path.String() return p == "LastUsed" || p == "Device.Webauthn.SignatureCounter" - }, cmp.Ignore())) + }, gocmp.Ignore())) } - if diff := cmp.Diff(want, got, opts...); diff != "" { + if diff := gocmp.Diff(want, got, opts...); diff != "" { t.Errorf("compareDevices mismatch (-want +got):\n%s", diff) } } @@ -4444,3 +4445,120 @@ func newGlobalNotificationWithExpiry(t *testing.T, title string, expires *timest return ¬ification } + +// TestServerHostnameSanitization tests that persisting servers with +// "invalid" hostnames results in the hostname being sanitized and the +// illegal name being placed in a label. +func TestServerHostnameSanitization(t *testing.T) { + t.Parallel() + ctx := context.Background() + srv, err := NewTestAuthServer(TestAuthServerConfig{Dir: t.TempDir()}) + require.NoError(t, err) + + cases := []struct { + name string + hostname string + addr string + invalidHostname bool + invalidAddr bool + }{ + { + name: "valid dns hostname", + hostname: "llama.example.com", + }, + { + name: "valid friendly hostname", + hostname: "llama", + }, + { + name: "uuid hostname", + hostname: uuid.NewString(), + }, + { + name: "uuid dns hostname", + hostname: uuid.NewString() + ".example.com", + }, + { + name: "empty hostname", + hostname: "", + invalidHostname: true, + }, + { + name: "exceptionally long hostname", + hostname: strings.Repeat("a", serverHostnameMaxLen*2), + invalidHostname: true, + }, + { + name: "invalid dns hostname", + hostname: "llama..example.com", + invalidHostname: true, + }, + { + name: "spaces in hostname", + hostname: "the quick brown fox jumps over the lazy dog", + invalidHostname: true, + }, + { + name: "invalid addr", + hostname: "..", + addr: "..:2345", + invalidHostname: true, + invalidAddr: true, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + for _, subKind := range []string{types.KindNode, types.SubKindOpenSSHNode} { + t.Run(subKind, func(t *testing.T) { + server := &types.ServerV2{ + Kind: types.KindNode, + SubKind: subKind, + Metadata: types.Metadata{ + Name: uuid.NewString(), + }, + Spec: types.ServerSpecV2{ + Hostname: test.hostname, + Addr: cmp.Or(test.addr, "abcd:1234"), + }, + } + if subKind == types.KindNode { + server.SubKind = "" + } + + _, err = srv.AuthServer.UpsertNode(ctx, server) + require.NoError(t, err) + + replacedValue, _ := server.GetLabel("teleport.internal/invalid-hostname") + if !test.invalidHostname { + assert.Equal(t, test.hostname, server.GetHostname()) + assert.Empty(t, replacedValue) + return + } + + assert.Equal(t, test.hostname, replacedValue) + switch subKind { + case types.SubKindOpenSSHNode: + host, _, err := net.SplitHostPort(server.GetAddr()) + assert.NoError(t, err) + if !test.invalidAddr { + // If the address is valid, then the hostname should be set + // to the host of the addr field. + assert.Equal(t, host, server.GetHostname()) + } else { + // If the address is not valid, then the hostname should be + // set to a UUID. + assert.NotEqual(t, host, server.GetHostname()) + assert.NotEqual(t, server.GetName(), server.GetHostname()) + + _, err := uuid.Parse(server.GetHostname()) + require.NoError(t, err) + } + default: + assert.Equal(t, server.GetName(), server.GetHostname()) + } + }) + } + }) + } +} diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index aef8897492d80..2890b255f7114 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -2955,9 +2955,9 @@ func TestNodesCRUD(t *testing.T) { require.NoError(t, err) // node1 and node2 will be added to default namespace - node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{}, nil) + node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{Hostname: "node1"}, nil) require.NoError(t, err) - node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{}, nil) + node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{Hostname: "node2"}, nil) require.NoError(t, err) t.Run("CreateNode", func(t *testing.T) { diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index e730a9b7daff0..28525d1cde113 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -371,6 +371,34 @@ func (s *PresenceService) UpsertNode(ctx context.Context, server types.Server) ( }, nil } +// UpdateNode conditionally updates the provided server. +func (s *PresenceService) UpdateNode(ctx context.Context, server types.Server) (types.Server, error) { + if server.GetNamespace() == "" { + server.SetNamespace(apidefaults.Namespace) + } + + if n := server.GetNamespace(); n != apidefaults.Namespace { + return nil, trace.BadParameter("cannot place node in namespace %q, custom namespaces are deprecated", n) + } + rev := server.GetRevision() + value, err := services.MarshalServer(server) + if err != nil { + return nil, trace.Wrap(err) + } + lease, err := s.ConditionalUpdate(ctx, backend.Item{ + Key: backend.NewKey(nodesPrefix, server.GetNamespace(), server.GetName()), + Value: value, + Expires: server.Expiry(), + Revision: rev, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + server.SetRevision(lease.Revision) + return server, nil +} + // GetAuthServers returns a list of registered servers func (s *PresenceService) GetAuthServers() ([]types.Server, error) { return s.getServers(context.TODO(), types.KindAuthServer, authServersPrefix) diff --git a/lib/services/local/presence_test.go b/lib/services/local/presence_test.go index a49fcafc65100..7ed3d8ded7112 100644 --- a/lib/services/local/presence_test.go +++ b/lib/services/local/presence_test.go @@ -262,6 +262,31 @@ func TestNodeCRUD(t *testing.T) { require.NoError(t, err) }) + t.Run("UpdateNode", func(t *testing.T) { + node1, err = presence.GetNode(ctx, apidefaults.Namespace, node1.GetName()) + require.NoError(t, err) + node1.SetAddr("1.2.3.4:8080") + + node2, err = presence.GetNode(ctx, apidefaults.Namespace, node2.GetName()) + require.NoError(t, err) + + node1, err = presence.UpdateNode(ctx, node1) + require.NoError(t, err) + require.Equal(t, "1.2.3.4:8080", node1.GetAddr()) + + rev := node2.GetRevision() + node2.SetAddr("1.2.3.4:9090") + node2.SetRevision(node1.GetRevision()) + + _, err = presence.UpdateNode(ctx, node2) + require.True(t, trace.IsCompareFailed(err)) + node2.SetRevision(rev) + + node2, err = presence.UpdateNode(ctx, node2) + require.NoError(t, err) + require.Equal(t, "1.2.3.4:9090", node2.GetAddr()) + }) + // Run NodeGetters in nested subtests to allow parallelization. t.Run("NodeGetters", func(t *testing.T) { t.Run("GetNodes", func(t *testing.T) { diff --git a/lib/services/presence.go b/lib/services/presence.go index d69bfb8eed6d0..12832a6e8c101 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -212,4 +212,5 @@ type PresenceInternal interface { UpsertHostUserInteractionTime(ctx context.Context, name string, loginTime time.Time) error GetHostUserInteractionTime(ctx context.Context, name string) (time.Time, error) UpsertReverseTunnelV2(ctx context.Context, tunnel types.ReverseTunnel) (types.ReverseTunnel, error) + UpdateNode(ctx context.Context, server types.Server) (types.Server, error) } diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 880005063464d..ad2cb4695b3f2 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -488,6 +488,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) { require.Empty(t, out) srv := NewServer(types.KindNode, "srv1", "127.0.0.1:2022", apidefaults.Namespace) + srv.Spec.Hostname = "llama" _, err = s.PresenceS.UpsertNode(ctx, srv) require.NoError(t, err) @@ -513,6 +514,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) { require.Empty(t, out) proxy := NewServer(types.KindProxy, "proxy1", "127.0.0.1:2023", apidefaults.Namespace) + proxy.Spec.Hostname = "proxy.llama" require.NoError(t, s.PresenceS.UpsertProxy(ctx, proxy)) out, err = s.PresenceS.GetProxies() @@ -533,6 +535,7 @@ func (s *ServicesTestSuite) ServerCRUD(t *testing.T) { require.Empty(t, out) auth := NewServer(types.KindAuthServer, "auth1", "127.0.0.1:2025", apidefaults.Namespace) + auth.Spec.Hostname = "auth.llama" require.NoError(t, s.PresenceS.UpsertAuthServer(ctx, auth)) out, err = s.PresenceS.GetAuthServers() diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 40e2062c91a5d..ff0f12fdc20cb 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -1150,7 +1150,7 @@ func TestClusterNodesGet(t *testing.T) { server1 := servers[0] // Add another node. - server2, err := types.NewServerWithLabels("server2", types.KindNode, types.ServerSpecV2{}, map[string]string{"test-field": "test-value"}) + server2, err := types.NewServerWithLabels("server2", types.KindNode, types.ServerSpecV2{Hostname: "server2"}, map[string]string{"test-field": "test-value"}) require.NoError(t, err) _, err = env.server.Auth().UpsertNode(context.Background(), server2) require.NoError(t, err) @@ -1186,7 +1186,8 @@ func TestClusterNodesGet(t *testing.T) { Kind: types.KindNode, SubKind: types.SubKindTeleportNode, ClusterName: clusterName, - Name: "server2", + Name: server2.GetName(), + Hostname: server2.GetHostname(), Labels: []ui.Label{{Name: "test-field", Value: "test-value"}}, Tunnel: false, SSHLogins: []string{pack.login},