diff --git a/lib/auth/auth.go b/lib/auth/auth.go index c240ad6fc585f..4a88c5e083603 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1508,12 +1508,11 @@ func (a *Server) runPeriodicOperations() { heartbeatsMissedByAuth.Inc() } + if srv.GetSubKind() != types.SubKindOpenSSHNode { + return false, nil + } // TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then. if !validServerHostname(srv.GetHostname()) { - if srv.GetSubKind() != types.SubKindOpenSSHNode { - return false, nil - } - logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname()) logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname") @@ -1527,6 +1526,17 @@ func (a *Server) runPeriodicOperations() { if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) { logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err) } + } else if oldHostname, ok := srv.GetLabel(replacedHostnameLabel); ok && validServerHostname(oldHostname) { + // If the hostname has been replaced by a sanitized version, revert it back to the original + // if the original is valid under the most recent rules. + logger := a.logger.With("server", srv.GetName(), "old_hostname", oldHostname, "sanitized_hostname", srv.GetHostname()) + if err := restoreSanitizedHostname(srv); err != nil { + logger.WarnContext(a.closeCtx, "failed to restore sanitized static SSH server hostname", "error", err) + return false, nil + } + if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) { + log.Warnf("Failed to update node hostname: %v", err) + } } return false, nil @@ -5650,7 +5660,7 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error { const ( serverHostnameMaxLen = 256 - serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$` + serverHostnameRegexPattern = `^[a-zA-Z0-9]+[a-zA-Z0-9\.-]*$` replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname" ) @@ -5658,7 +5668,7 @@ var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern) // validServerHostname returns false if the hostname is longer than 256 characters or // does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also -// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol. +// cannot begin with a symbol. func validServerHostname(hostname string) bool { return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname) } @@ -5697,6 +5707,26 @@ func sanitizeHostname(server types.Server) error { return nil } +// restoreSanitizedHostname restores the original hostname of a server and removes the label. +func restoreSanitizedHostname(server types.Server) error { + oldHostname, ok := server.GetLabels()[replacedHostnameLabel] + // if the label is not present or the hostname is invalid under the most recent rules, do nothing. + if !ok || !validServerHostname(oldHostname) { + return nil + } + + switch s := server.(type) { + case *types.ServerV2: + // restore the original hostname and remove the label. + s.Spec.Hostname = oldHostname + delete(s.Metadata.Labels, replacedHostnameLabel) + default: + return trace.BadParameter("invalid server provided") + } + + return nil +} + // UpsertNode implements [services.Presence] by delegating to [Server.Services] // and potentially emitting a [usagereporter] event. func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) { diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index e4978e32e358a..8f535a1727588 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -4478,6 +4478,10 @@ func TestServerHostnameSanitization(t *testing.T) { name: "uuid dns hostname", hostname: uuid.NewString() + ".example.com", }, + { + name: "valid dns hostname with multi-dots", + hostname: "llama..example.com", + }, { name: "empty hostname", hostname: "", @@ -4488,11 +4492,6 @@ func TestServerHostnameSanitization(t *testing.T) { hostname: strings.Repeat("a", serverHostnameMaxLen*2), invalidHostname: true, }, - { - name: "invalid dns hostname", - hostname: "llama..example.com", - invalidHostname: true, - }, { name: "spaces in hostname", hostname: "the quick brown fox jumps over the lazy dog", @@ -4562,3 +4561,74 @@ func TestServerHostnameSanitization(t *testing.T) { }) } } + +func TestValidServerHostname(t *testing.T) { + t.Parallel() + tests := []struct { + name string + hostname string + want bool + }{ + { + name: "valid dns hostname", + hostname: "llama.example.com", + want: true, + }, + { + name: "valid friendly hostname", + hostname: "llama", + want: true, + }, + { + name: "uuid hostname", + hostname: uuid.NewString(), + want: true, + }, + { + name: "valid hostname with multi-dashes", + hostname: "llama--example.com", + want: true, + }, + { + name: "valid hostname with multi-dots", + hostname: "llama..example.com", + want: true, + }, + { + name: "valid hostname with numbers", + hostname: "llama9", + want: true, + }, + { + name: "hostname with invalid characters", + hostname: "llama?!$", + want: false, + }, + { + name: "super long hostname", + hostname: strings.Repeat("a", serverHostnameMaxLen*2), + want: false, + }, + { + name: "hostname with spaces", + hostname: "the quick brown fox jumps over the lazy dog", + want: false, + }, + { + name: "hostname with ;", + hostname: "llama;example.com", + want: false, + }, + { + name: "empty hostname", + hostname: "", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := validServerHostname(tt.hostname) + require.Equal(t, tt.want, got) + }) + } +}