Skip to content

Commit

Permalink
[v17] Sanitize SSH server hostnames (#49091)
Browse files Browse the repository at this point in the history
* Sanitize SSH server hostnames

Prevents any invalid and malicious hostnames, but replacing them with
known valid data already associated with the host. This was chosen
instead of rejecting to persist the server resource in an attempt to
continue providing access to the host in order to remedy the invalid
hostname.

Any servers that represent a Teleport ssh_service with an invalid
hostname will be replaced by the host UUID. Any static OpenSSH servers
will have invalid hostnames replaced with the address. This will continue
to allow the hosts to be dialable. In order to make these hosts
discoverable, the invalid hostname will be set in the
"teleport.internal/invalid-hostname" label.

Updates gravitational/teleport-private#1676.

* add and use internal update node method

* add test coverage for UpdateNode
  • Loading branch information
rosstimothy authored Nov 15, 2024
1 parent 60c78b2 commit 723f751
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 14 deletions.
83 changes: 83 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ import (
"log/slog"
"math/big"
insecurerand "math/rand"
"net"
"os"
"regexp"
"slices"
"sort"
"strconv"
Expand Down Expand Up @@ -1470,6 +1472,28 @@ func (a *Server) runPeriodicOperations() {
if services.NodeHasMissedKeepAlives(srv) {
missedKeepAliveCount++
}

// TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then.
if !validServerHostname(srv.GetHostname()) {
if srv.GetSubKind() != types.SubKindOpenSSHNode {
return false, nil
}

logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname())

logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname")
// Any existing static hosts will not have their
// hostname sanitized since they don't heartbeat.
if err := sanitizeHostname(srv); err != nil {
logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err)
return false, nil
}

if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) {
logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err)
}
}

return false, nil
},
req,
Expand Down Expand Up @@ -5618,9 +5642,68 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error {
return nil
}

const (
serverHostnameMaxLen = 256
serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$`
replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname"
)

var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern)

// validServerHostname returns false if the hostname is longer than 256 characters or
// does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also
// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol.
func validServerHostname(hostname string) bool {
return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname)
}

func sanitizeHostname(server types.Server) error {
invalidHostname := server.GetHostname()

replacedHostname := server.GetName()
if server.GetSubKind() == types.SubKindOpenSSHNode {
host, _, err := net.SplitHostPort(server.GetAddr())
if err != nil || !validServerHostname(host) {
id, err := uuid.NewRandom()
if err != nil {
return trace.Wrap(err)
}

host = id.String()
}

replacedHostname = host
}

switch s := server.(type) {
case *types.ServerV2:
s.Spec.Hostname = replacedHostname

if s.Metadata.Labels == nil {
s.Metadata.Labels = map[string]string{}
}

s.Metadata.Labels[replacedHostnameLabel] = invalidHostname
default:
return trace.BadParameter("invalid server provided")
}

return nil
}

// UpsertNode implements [services.Presence] by delegating to [Server.Services]
// and potentially emitting a [usagereporter] event.
func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) {
if !validServerHostname(server.GetHostname()) {
a.logger.DebugContext(a.closeCtx, "sanitizing invalid server hostname",
"server", server.GetName(),
"hostname", server.GetHostname(),
)
if err := sanitizeHostname(server); err != nil {
return nil, trace.Wrap(err)
}
}

lease, err := a.Services.UpsertNode(ctx, server)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
138 changes: 128 additions & 10 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package auth

import (
"cmp"
"context"
"crypto/rand"
"crypto/x509"
Expand All @@ -34,7 +35,7 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
gocmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/license"
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestSessions(t *testing.T) {
require.NoError(t, err)
assert.Empty(t, out.GetSSHPriv())
assert.Empty(t, out.GetTLSPriv())
assert.Empty(t, cmp.Diff(ws, out,
assert.Empty(t, gocmp.Diff(ws, out,
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
cmpopts.IgnoreFields(types.WebSessionSpecV2{}, "Priv", "TLSPriv")))

Expand Down Expand Up @@ -1655,7 +1656,7 @@ func TestServer_AugmentContextUserCertificates(t *testing.T) {
AssetTag: test.opts.DeviceExtensions.AssetTag,
CredentialId: test.opts.DeviceExtensions.CredentialID,
}
if diff := cmp.Diff(want, got); diff != "" {
if diff := gocmp.Diff(want, got); diff != "" {
t.Errorf("certEvent.Identity.DeviceExtensions mismatch (-want +got)\n%s", diff)
}
}
Expand Down Expand Up @@ -2301,12 +2302,12 @@ func TestServer_ExtendWebSession_deviceExtensions(t *testing.T) {
// Assert TLS extensions.
_, newIdentity := parseX509PEMAndIdentity(t, newSession.GetTLSCert())
wantExts := tlsca.DeviceExtensions(*deviceExts)
if diff := cmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
if diff := gocmp.Diff(wantExts, newIdentity.DeviceExtensions); diff != "" {
t.Errorf("newSession.TLSCert DeviceExtensions mismatch (-want +got)\n%s", diff)
}

// Assert SSH extensions.
if diff := cmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
if diff := gocmp.Diff(wantExts, parseSSHDeviceExtensions(t, newSession.GetPub())); diff != "" {
t.Errorf("newSession.Pub DeviceExtensions mismatch (-want +got)\n%s", diff)
}
})
Expand Down Expand Up @@ -2545,7 +2546,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) {
// Validate audit event.
lastEvent := p.mockEmitter.LastEvent()
require.IsType(t, &apievents.CertificateCreate{}, lastEvent)
require.Empty(t, cmp.Diff(
require.Empty(t, gocmp.Diff(
&apievents.CertificateCreate{
Metadata: apievents.Metadata{
Type: events.CertificateCreateEvent,
Expand Down Expand Up @@ -3801,15 +3802,15 @@ func compareDevices(t *testing.T, ignoreUpdateAndCounter bool, got []*types.MFAD
}

// Ignore LastUsed and SignatureCounter?
var opts []cmp.Option
var opts []gocmp.Option
if ignoreUpdateAndCounter {
opts = append(opts, cmp.FilterPath(func(path cmp.Path) bool {
opts = append(opts, gocmp.FilterPath(func(path gocmp.Path) bool {
p := path.String()
return p == "LastUsed" || p == "Device.Webauthn.SignatureCounter"
}, cmp.Ignore()))
}, gocmp.Ignore()))
}

if diff := cmp.Diff(want, got, opts...); diff != "" {
if diff := gocmp.Diff(want, got, opts...); diff != "" {
t.Errorf("compareDevices mismatch (-want +got):\n%s", diff)
}
}
Expand Down Expand Up @@ -4444,3 +4445,120 @@ func newGlobalNotificationWithExpiry(t *testing.T, title string, expires *timest

return &notification
}

// TestServerHostnameSanitization tests that persisting servers with
// "invalid" hostnames results in the hostname being sanitized and the
// illegal name being placed in a label.
func TestServerHostnameSanitization(t *testing.T) {
t.Parallel()
ctx := context.Background()
srv, err := NewTestAuthServer(TestAuthServerConfig{Dir: t.TempDir()})
require.NoError(t, err)

cases := []struct {
name string
hostname string
addr string
invalidHostname bool
invalidAddr bool
}{
{
name: "valid dns hostname",
hostname: "llama.example.com",
},
{
name: "valid friendly hostname",
hostname: "llama",
},
{
name: "uuid hostname",
hostname: uuid.NewString(),
},
{
name: "uuid dns hostname",
hostname: uuid.NewString() + ".example.com",
},
{
name: "empty hostname",
hostname: "",
invalidHostname: true,
},
{
name: "exceptionally long hostname",
hostname: strings.Repeat("a", serverHostnameMaxLen*2),
invalidHostname: true,
},
{
name: "invalid dns hostname",
hostname: "llama..example.com",
invalidHostname: true,
},
{
name: "spaces in hostname",
hostname: "the quick brown fox jumps over the lazy dog",
invalidHostname: true,
},
{
name: "invalid addr",
hostname: "..",
addr: "..:2345",
invalidHostname: true,
invalidAddr: true,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
for _, subKind := range []string{types.KindNode, types.SubKindOpenSSHNode} {
t.Run(subKind, func(t *testing.T) {
server := &types.ServerV2{
Kind: types.KindNode,
SubKind: subKind,
Metadata: types.Metadata{
Name: uuid.NewString(),
},
Spec: types.ServerSpecV2{
Hostname: test.hostname,
Addr: cmp.Or(test.addr, "abcd:1234"),
},
}
if subKind == types.KindNode {
server.SubKind = ""
}

_, err = srv.AuthServer.UpsertNode(ctx, server)
require.NoError(t, err)

replacedValue, _ := server.GetLabel("teleport.internal/invalid-hostname")
if !test.invalidHostname {
assert.Equal(t, test.hostname, server.GetHostname())
assert.Empty(t, replacedValue)
return
}

assert.Equal(t, test.hostname, replacedValue)
switch subKind {
case types.SubKindOpenSSHNode:
host, _, err := net.SplitHostPort(server.GetAddr())
assert.NoError(t, err)
if !test.invalidAddr {
// If the address is valid, then the hostname should be set
// to the host of the addr field.
assert.Equal(t, host, server.GetHostname())
} else {
// If the address is not valid, then the hostname should be
// set to a UUID.
assert.NotEqual(t, host, server.GetHostname())
assert.NotEqual(t, server.GetName(), server.GetHostname())

_, err := uuid.Parse(server.GetHostname())
require.NoError(t, err)
}
default:
assert.Equal(t, server.GetName(), server.GetHostname())
}
})
}
})
}
}
4 changes: 2 additions & 2 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2955,9 +2955,9 @@ func TestNodesCRUD(t *testing.T) {
require.NoError(t, err)

// node1 and node2 will be added to default namespace
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{}, nil)
node1, err := types.NewServerWithLabels("node1", types.KindNode, types.ServerSpecV2{Hostname: "node1"}, nil)
require.NoError(t, err)
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{}, nil)
node2, err := types.NewServerWithLabels("node2", types.KindNode, types.ServerSpecV2{Hostname: "node2"}, nil)
require.NoError(t, err)

t.Run("CreateNode", func(t *testing.T) {
Expand Down
28 changes: 28 additions & 0 deletions lib/services/local/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,34 @@ func (s *PresenceService) UpsertNode(ctx context.Context, server types.Server) (
}, nil
}

// UpdateNode conditionally updates the provided server.
func (s *PresenceService) UpdateNode(ctx context.Context, server types.Server) (types.Server, error) {
if server.GetNamespace() == "" {
server.SetNamespace(apidefaults.Namespace)
}

if n := server.GetNamespace(); n != apidefaults.Namespace {
return nil, trace.BadParameter("cannot place node in namespace %q, custom namespaces are deprecated", n)
}
rev := server.GetRevision()
value, err := services.MarshalServer(server)
if err != nil {
return nil, trace.Wrap(err)
}
lease, err := s.ConditionalUpdate(ctx, backend.Item{
Key: backend.NewKey(nodesPrefix, server.GetNamespace(), server.GetName()),
Value: value,
Expires: server.Expiry(),
Revision: rev,
})
if err != nil {
return nil, trace.Wrap(err)
}

server.SetRevision(lease.Revision)
return server, nil
}

// GetAuthServers returns a list of registered servers
func (s *PresenceService) GetAuthServers() ([]types.Server, error) {
return s.getServers(context.TODO(), types.KindAuthServer, authServersPrefix)
Expand Down
25 changes: 25 additions & 0 deletions lib/services/local/presence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,31 @@ func TestNodeCRUD(t *testing.T) {
require.NoError(t, err)
})

t.Run("UpdateNode", func(t *testing.T) {
node1, err = presence.GetNode(ctx, apidefaults.Namespace, node1.GetName())
require.NoError(t, err)
node1.SetAddr("1.2.3.4:8080")

node2, err = presence.GetNode(ctx, apidefaults.Namespace, node2.GetName())
require.NoError(t, err)

node1, err = presence.UpdateNode(ctx, node1)
require.NoError(t, err)
require.Equal(t, "1.2.3.4:8080", node1.GetAddr())

rev := node2.GetRevision()
node2.SetAddr("1.2.3.4:9090")
node2.SetRevision(node1.GetRevision())

_, err = presence.UpdateNode(ctx, node2)
require.True(t, trace.IsCompareFailed(err))
node2.SetRevision(rev)

node2, err = presence.UpdateNode(ctx, node2)
require.NoError(t, err)
require.Equal(t, "1.2.3.4:9090", node2.GetAddr())
})

// Run NodeGetters in nested subtests to allow parallelization.
t.Run("NodeGetters", func(t *testing.T) {
t.Run("GetNodes", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions lib/services/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,5 @@ type PresenceInternal interface {
UpsertHostUserInteractionTime(ctx context.Context, name string, loginTime time.Time) error
GetHostUserInteractionTime(ctx context.Context, name string) (time.Time, error)
UpsertReverseTunnelV2(ctx context.Context, tunnel types.ReverseTunnel) (types.ReverseTunnel, error)
UpdateNode(ctx context.Context, server types.Server) (types.Server, error)
}
Loading

0 comments on commit 723f751

Please sign in to comment.