Skip to content

Commit

Permalink
Prevent overwriting existing host_uuid file (#48012)
Browse files Browse the repository at this point in the history
In some circumstances, multiple Teleport processes may be trying
to write the host_uuid file in the same data directory simultaneously.
The last of the writers would win, and any process using a host
UUID that did not match what ended up on disk could get into a perpertual
state of being unable to connect to the cluster.

To avoid the raciness, the host_uuid file writing process is no
longer a blind upsert. Instead, special care is taken to ensure
that there can only be a single writer, and that any subsequent
updates to the file are aborted and the first value written is
used instead.
  • Loading branch information
rosstimothy authored Nov 4, 2024
1 parent 5d7eb65 commit dc88db7
Show file tree
Hide file tree
Showing 14 changed files with 335 additions and 141 deletions.
15 changes: 8 additions & 7 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ import (
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
"github.com/gravitational/teleport/lib/utils/cert"
"github.com/gravitational/teleport/lib/utils/hostid"
logutils "github.com/gravitational/teleport/lib/utils/log"
vc "github.com/gravitational/teleport/lib/versioncontrol"
"github.com/gravitational/teleport/lib/versioncontrol/endpoint"
Expand Down Expand Up @@ -2934,7 +2935,7 @@ func (process *TeleportProcess) initSSH() error {
storagePresence := local.NewPresenceService(process.storage.BackendStorage)

// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
serverID, err := hostid.ReadOrCreateFile(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -4439,7 +4440,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}

// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
serverID, err := hostid.ReadOrCreateFile(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -6498,7 +6499,7 @@ func readOrGenerateHostID(ctx context.Context, cfg *servicecfg.Config, kubeBacke
if err := persistHostIDToStorages(ctx, cfg, kubeBackend); err != nil {
return trace.Wrap(err)
}
} else if kubeBackend != nil && utils.HostUUIDExistsLocally(cfg.DataDir) {
} else if kubeBackend != nil && hostid.ExistsLocally(cfg.DataDir) {
// This case is used when loading a Teleport pre-11 agent with storage attached.
// In this case, we have to copy the "host_uuid" from the agent to the secret
// in case storage is removed later.
Expand Down Expand Up @@ -6537,14 +6538,14 @@ func readHostIDFromStorages(ctx context.Context, dataDir string, kubeBackend kub
}
// Even if running in Kubernetes fallback to local storage if `host_uuid` was
// not found in secret.
hostID, err := utils.ReadHostUUID(dataDir)
hostID, err := hostid.ReadFile(dataDir)
return hostID, trace.Wrap(err)
}

// persistHostIDToStorages writes the cfg.HostUUID to local data and to
// Kubernetes Secret if this process is running on a Kubernetes Cluster.
func persistHostIDToStorages(ctx context.Context, cfg *servicecfg.Config, kubeBackend kubernetesBackend) error {
if err := utils.WriteHostUUID(cfg.DataDir, cfg.HostUUID); err != nil {
if err := hostid.WriteFile(cfg.DataDir, cfg.HostUUID); err != nil {
if errors.Is(err, fs.ErrPermission) {
cfg.Logger.ErrorContext(ctx, "Teleport does not have permission to write to the data directory. Ensure that you are running as a user with appropriate permissions.", "data_dir", cfg.DataDir)
}
Expand All @@ -6563,7 +6564,7 @@ func persistHostIDToStorages(ctx context.Context, cfg *servicecfg.Config, kubeBa
// loadHostIDFromKubeSecret reads the host_uuid from the Kubernetes secret with
// the expected key: `/host_uuid`.
func loadHostIDFromKubeSecret(ctx context.Context, kubeBackend kubernetesBackend) (string, error) {
item, err := kubeBackend.Get(ctx, backend.NewKey(utils.HostUUIDFile))
item, err := kubeBackend.Get(ctx, backend.NewKey(hostid.FileName))
if err != nil {
return "", trace.Wrap(err)
}
Expand All @@ -6576,7 +6577,7 @@ func writeHostIDToKubeSecret(ctx context.Context, kubeBackend kubernetesBackend,
_, err := kubeBackend.Put(
ctx,
backend.Item{
Key: backend.NewKey(utils.HostUUIDFile),
Key: backend.NewKey(hostid.FileName),
Value: []byte(id),
},
)
Expand Down
3 changes: 2 additions & 1 deletion lib/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -1167,7 +1168,7 @@ func Test_readOrGenerateHostID(t *testing.T) {
dataDir := t.TempDir()
// write host_uuid file to temp dir.
if len(tt.args.hostIDContent) > 0 {
err := utils.WriteHostUUID(dataDir, tt.args.hostIDContent)
err := hostid.WriteFile(dataDir, tt.args.hostIDContent)
require.NoError(t, err)
}

Expand Down
3 changes: 2 additions & 1 deletion lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import (
"github.com/gravitational/teleport/lib/sshutils/x11"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

var log = logrus.WithFields(logrus.Fields{
Expand Down Expand Up @@ -724,7 +725,7 @@ func New(
options ...ServerOption,
) (*Server, error) {
// read the host UUID:
uuid, err := utils.ReadOrMakeHostUUID(dataDir)
uuid, err := hostid.ReadOrCreateFile(dataDir)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
7 changes: 4 additions & 3 deletions lib/teleterm/services/connectmycomputer/connectmycomputer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/teleterm/clusters"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

type RoleSetup struct {
Expand Down Expand Up @@ -395,7 +396,7 @@ func (n *NodeJoinWait) getNodeNameFromHostUUIDFile(ctx context.Context, cluster
// the file is empty.
//
// Here we need to be able to distinguish between both of those two cases.
out, err := utils.ReadPath(utils.GetHostUUIDPath(dataDir))
out, err := utils.ReadPath(hostid.GetPath(dataDir))
if err != nil {
if trace.IsNotFound(err) {
continue
Expand Down Expand Up @@ -536,7 +537,7 @@ type NodeDelete struct {

// Run grabs the host UUID of an agent from a disk and deletes the node with that name.
func (n *NodeDelete) Run(ctx context.Context, presence Presence, cluster *clusters.Cluster) error {
hostUUID, err := utils.ReadHostUUID(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
hostUUID, err := hostid.ReadFile(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
if trace.IsNotFound(err) {
return nil
}
Expand Down Expand Up @@ -585,7 +586,7 @@ type NodeName struct {

// Get returns the host UUID of the agent from a disk.
func (n *NodeName) Get(cluster *clusters.Cluster) (string, error) {
hostUUID, err := utils.ReadHostUUID(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
hostUUID, err := hostid.ReadFile(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
return hostUUID, trace.Wrap(err)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/teleterm/api/uri"
"github.com/gravitational/teleport/lib/teleterm/clusters"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

func TestRoleSetupRun_WithNonLocalUser(t *testing.T) {
Expand Down Expand Up @@ -472,7 +472,7 @@ func mustMakeHostUUIDFile(t *testing.T, agentsDir string, profileName string) st
err = os.MkdirAll(dataDir, agentsDirStat.Mode())
require.NoError(t, err)

hostUUID, err := utils.ReadOrMakeHostUUID(dataDir)
hostUUID, err := hostid.ReadOrCreateFile(dataDir)
require.NoError(t, err)

return hostUUID
Expand Down
61 changes: 61 additions & 0 deletions lib/utils/hostid/hostid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package hostid

import (
"errors"
"io/fs"
"path/filepath"
"strings"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/utils"
)

const (
// FileName is the file name where the host UUID file is stored
FileName = "host_uuid"
)

// GetPath returns the path to the host UUID file given the data directory.
func GetPath(dataDir string) string {
return filepath.Join(dataDir, FileName)
}

// ExistsLocally checks if dataDir/host_uuid file exists in local storage.
func ExistsLocally(dataDir string) bool {
_, err := ReadFile(dataDir)
return err == nil
}

// ReadFile reads host UUID from the file in the data dir
func ReadFile(dataDir string) (string, error) {
out, err := utils.ReadPath(GetPath(dataDir))
if err != nil {
if errors.Is(err, fs.ErrPermission) {
//do not convert to system error as this loses the ability to compare that it is a permission error
return "", trace.Wrap(err)
}
return "", trace.ConvertSystemError(err)
}
id := strings.TrimSpace(string(out))
if id == "" {
return "", trace.NotFound("host uuid is empty")
}
return id, nil
}
113 changes: 113 additions & 0 deletions lib/utils/hostid/hostid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//go:build !windows

// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package hostid_test

import (
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

func TestMain(m *testing.M) {
utils.InitLoggerForTests()
os.Exit(m.Run())
}

func TestReadOrCreate(t *testing.T) {
t.Parallel()

dir := t.TempDir()

var wg errgroup.Group
concurrency := 10
ids := make([]string, concurrency)
barrier := make(chan struct{})

for i := 0; i < concurrency; i++ {
wg.Go(func() error {
<-barrier
id, err := hostid.ReadOrCreateFile(dir)
ids[i] = id
return err
})
}

close(barrier)

require.NoError(t, wg.Wait())
require.Equal(t, slices.Repeat([]string{ids[0]}, concurrency), ids)
}

func TestIdempotence(t *testing.T) {
t.Parallel()

// call twice, get same result
dir := t.TempDir()
id, err := hostid.ReadOrCreateFile(dir)
require.Len(t, id, 36)
require.NoError(t, err)
uuidCopy, err := hostid.ReadOrCreateFile(dir)
require.NoError(t, err)
require.Equal(t, id, uuidCopy)
}

func TestBadLocation(t *testing.T) {
t.Parallel()

// call with a read-only dir, make sure to get an error
id, err := hostid.ReadOrCreateFile("/bad-location")
require.Empty(t, id)
require.Error(t, err)
require.Regexp(t, "^.*no such file or directory.*$", err.Error())
}

func TestIgnoreWhitespace(t *testing.T) {
t.Parallel()

// newlines are getting ignored
dir := t.TempDir()
id := fmt.Sprintf("%s\n", uuid.NewString())
err := os.WriteFile(filepath.Join(dir, hostid.FileName), []byte(id), 0666)
require.NoError(t, err)
out, err := hostid.ReadFile(dir)
require.NoError(t, err)
require.Equal(t, strings.TrimSpace(id), out)
}

func TestRegenerateEmpty(t *testing.T) {
t.Parallel()

// empty UUID in file is regenerated
dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, hostid.FileName), nil, 0666)
require.NoError(t, err)
out, err := hostid.ReadOrCreateFile(dir)
require.NoError(t, err)
require.Len(t, out, 36)
}
Loading

0 comments on commit dc88db7

Please sign in to comment.