Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with instances sharing config dir. #70

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/antifuchs/tsnsrv
go 1.21

require (
github.com/gofrs/flock v0.8.1
github.com/peterbourgon/ff/v3 v3.4.0
github.com/prometheus/client_golang v1.17.0
github.com/stretchr/testify v1.8.4
Expand Down Expand Up @@ -80,7 +81,7 @@ require (
golang.org/x/mod v0.11.0 // indirect
golang.org/x/net v0.14.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.11.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/term v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/time v0.3.0 // indirect
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand Down Expand Up @@ -239,8 +241,8 @@ golang.org/x/sys v0.0.0-20210301091718-77cc2087c03b/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
Expand Down
10 changes: 9 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func tailnetSrvFromArgs(args []string) (*validTailnetSrv, *ffcli.Command, error)
fs.DurationVar(&s.Timeout, "timeout", 1*time.Minute, "Timeout connecting to the tailnet")
fs.Var(&s.AllowedPrefixes, "prefix", "Allowed URL prefixes; if none is set, all prefixes are allowed")
fs.BoolVar(&s.StripPrefix, "stripPrefix", true, "Strip prefixes that matched; best set to false if allowing multiple prefixes")
fs.StringVar(&s.StateDir, "stateDir", os.Getenv("TS_STATE_DIR"), "Directory containing the persistent tailscale status files. Can also be set by $TS_STATE_DIR; this option takes precedence.")
fs.StringVar(&s.StateDir, "stateDir", "", "Directory containing the persistent tailscale status files. Can also be set by $TS_STATE_DIR; this option takes precedence.")
fs.StringVar(&s.AuthkeyPath, "authkeyPath", "", "File containing a tailscale auth key. Key is assumed to be in $TS_AUTHKEY in absence of this option.")
fs.BoolVar(&s.InsecureHTTPS, "insecureHTTPS", false, "Disable TLS certificate validation on upstream")
fs.DurationVar(&s.WhoisTimeout, "whoisTimeout", 1*time.Second, "Maximum amount of time to spend looking up client identities")
Expand All @@ -123,6 +123,14 @@ func tailnetSrvFromArgs(args []string) (*validTailnetSrv, *ffcli.Command, error)
if err := root.Parse(args); err != nil {
return nil, root, fmt.Errorf("could not parse args: %w", err)
}

// Figure out the state directory
stateDir, err := NewStateDir(s.Name, s.StateDir).Compute()
if err != nil {
return nil, nil, fmt.Errorf("unable to compute state dir: %w", err)
}
s.StateDir = stateDir

valid, err := s.validate(root.FlagSet.Args())
if err != nil {
return nil, root, fmt.Errorf("failed to validate args: %w", err)
Expand Down
159 changes: 159 additions & 0 deletions state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package main

import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path"
"strings"
"time"

"github.com/gofrs/flock"
)

type StateDir struct {
machineName string
stateDirFlag string
getEnv func(string) string
userConfigDir func() (string, error)
dirExists func(string) (bool, error)
readFileString func(string) (string, error)
writeFileString func(string, string) error
}

func NewStateDir(machineName, stateDirFlag string) StateDir {
return StateDir{
machineName: machineName,
stateDirFlag: stateDirFlag,
getEnv: os.Getenv,
userConfigDir: os.UserConfigDir,
dirExists: dirExists,
readFileString: readFileString,
writeFileString: writeFileString,
}
}

func (sd StateDir) Compute() (string, error) {
// Set command line flag
if sd.stateDirFlag != "" {
return sd.stateDirFlag, nil
}

// Set TS_STATE_DIR env var
tsStateDirEnv := sd.getEnv("TS_STATE_DIR")
if tsStateDirEnv != "" {
return tsStateDirEnv, nil
}

// Looking for legacy tsnet-tsnsrv configuration directory
userConfigDir, err := sd.userConfigDir()
if err != nil {
return "", fmt.Errorf("unable to find user config directory. %w", err)
}
legacyTsnetConfigDir := path.Join(userConfigDir, "tsnet-tsnsrv")
legacyTsnetDirExists, err := sd.dirExists(legacyTsnetConfigDir)
if err != nil {
return "", fmt.Errorf("unable to determine existence of legacy tsnet config directory. %w", err)
}

// The tsnet-tsnet directory doesn't exist so we can just create a unique configuration directory for the given
// machine name.
if !legacyTsnetDirExists {
return path.Join(userConfigDir, fmt.Sprintf("tsnet-tsnsrv-%s", sd.machineName)), nil
}

// The tsnet-tsnet directory does exist reach the machine name file and see if they match
machineNamePath := path.Join(legacyTsnetConfigDir, "machine-name")
readName, err := sd.readFileString(machineNamePath)
if errors.Is(err, fs.ErrNotExist) {
err = sd.writeFileString(machineNamePath, sd.machineName)
if err != nil {
return "", fmt.Errorf("unable to write machine name to legacy config dir. %w", err)
}

return legacyTsnetConfigDir, nil
}
if err != nil {
return "", fmt.Errorf("unable to read legacy machine-name file. %w", err)
}

if strings.TrimSpace(readName) == sd.machineName {
return legacyTsnetConfigDir, nil
}

return path.Join(userConfigDir, fmt.Sprintf("tsnet-tsnsrv-%s", sd.machineName)), nil
}

func lockFilePath() string {
return path.Join(os.TempDir(), "tsnsrv.lock")
}

var errTryLockTimeout = errors.New("timeout trying to get the file lock")
var errTryLockUnlocked = errors.New("lock file remained unlocked")

func lockContext(ctx context.Context) context.Context {
ctx, _ = context.WithTimeoutCause(ctx, time.Second*5, errTryLockTimeout)
return ctx
}

func tryLock(ctx context.Context, readLock bool) (func(), error) {
lockFile := lockFilePath()
lock := flock.New(lockFile)
ctx = lockContext(ctx)
lockFn := lock.TryLockContext
if readLock {
lockFn = lock.TryRLockContext
}

locked, err := lockFn(ctx, time.Millisecond*100)
if errors.Is(err, errTryLockTimeout) {
return nil, fmt.Errorf("timeout trying to get lock %s another process is using it. %w", lockFile, errTryLockTimeout)
}
if err != nil {
return nil, fmt.Errorf("trying to lock %s. %w", lockFile, err)
}
if !locked {
return nil, fmt.Errorf("unable to get lock %s. %w", lockFile, errTryLockUnlocked)
}

return func() { _ = lock.Unlock() }, nil
}

func readFileString(file string) (string, error) {
unlocker, err := tryLock(context.Background(), true)
if err != nil {
return "", err
}
defer unlocker()

bytes, err := os.ReadFile(file)
return string(bytes), err
}

func writeFileString(file, contents string) error {
unlocker, err := tryLock(context.Background(), false)
if err != nil {
return err
}
defer unlocker()

err = os.WriteFile(file, []byte(contents), 0600)
if err != nil {
return fmt.Errorf("unable to write %s file. %w", file, err)
}
return nil
}

func dirExists(dir string) (bool, error) {
_, err := os.Stat(dir)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}

return false, fmt.Errorf("can't stat directory %s to see if it exists. %w", dir, err)
}
131 changes: 131 additions & 0 deletions state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package main

import (
"fmt"
"io/fs"
"path"
"testing"

"github.com/stretchr/testify/require"
)

func initialState() StateDir {
sd := NewStateDir("machine-name", "")
sd.getEnv = func(string) string { return "" }
sd.userConfigDir = func() (string, error) { return "", nil }
sd.dirExists = func(string) (bool, error) { return false, nil }
sd.readFileString = func(string) (string, error) { return "", nil }
sd.writeFileString = func(string, string) error { return nil }

return sd
}

// Ensure that the -stateDir flag is used for selecting the state directory.
func TestStateDirFlag_IsUsedIfSet(t *testing.T) {
t.Parallel()

const stateDirFlag = "some path"

sd := initialState()
sd.stateDirFlag = stateDirFlag

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, stateDirFlag, stateDir)
}

// Ensure that the TS_STATE_DIR environment variable is used for selecting the state directory.
func TestTSSTATEDIREnvVarIsUsedIfSet(t *testing.T) {
t.Parallel()

const stateDirEnv = "some path"

sd := initialState()
sd.getEnv = func(string) string { return stateDirEnv }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, stateDirEnv, stateDir)
}

// Ensure that the tsnet-tsnsrv is used if it exists and the machine_name file contents match the -name argument.
func TestTsnetTsnsrvDirIsUsedIfExistsAndMachineNameMatches(t *testing.T) {
t.Parallel()

const userConfigDir = "/home/somedir/.config/"
const legacyTsnetConfigDir = "/home/somedir/.config/tsnet-tsnsrv"

sd := initialState()
sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return true, nil }
sd.readFileString = func(file string) (string, error) { return sd.machineName, nil }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, legacyTsnetConfigDir, stateDir)
}

// Ensure that the machine_name file is created in tsnet-tsnsrv if it doesn't exist.
func TestMachineNameFileIsCreatedIfNeeded(t *testing.T) {
t.Parallel()

const userConfigDir = "/home/somedir/.config/"
const legacyTsnetConfigDir = "/home/somedir/.config/tsnet-tsnsrv"
machineNameFile := path.Join(legacyTsnetConfigDir, "machine-name")
writeFileStringCalled := false

sd := initialState()
sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return true, nil }
sd.readFileString = func(file string) (string, error) { return "", fs.ErrNotExist }
sd.writeFileString = func(file, contents string) error {
require.Equal(t, machineNameFile, file)
require.Equal(t, sd.machineName, contents)
writeFileStringCalled = true
return nil
}

stateDir, err := sd.Compute()

require.True(t, writeFileStringCalled)
require.NoError(t, err)
require.Equal(t, legacyTsnetConfigDir, stateDir)
}

// Ensure that tsnet-tsnsrv-<name> is used if a tsnet-tsnsrv directory doesn't exist
func TestTsnetTsnsrvNameIsUsedIfLegacyDirDoesntExist(t *testing.T) {
t.Parallel()

sd := initialState()
const userConfigDir = "/home/somedir/.config/"
newTsnetConfigDir := fmt.Sprintf("/home/somedir/.config/tsnet-tsnsrv-%s", sd.machineName)

sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return false, nil }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, newTsnetConfigDir, stateDir)
}

// Ensure that tsnet-tsnsrv-<name> is used if the machine_name doesn't match.
func TestTsnetTsnsrvNameIsUsedIfMachineNameDoesntMatch(t *testing.T) {
t.Parallel()

sd := initialState()
const userConfigDir = "/home/somedir/.config/"
newTsnetConfigDir := fmt.Sprintf("/home/somedir/.config/tsnet-tsnsrv-%s", sd.machineName)

sd.userConfigDir = func() (string, error) { return userConfigDir, nil }
sd.dirExists = func(dir string) (bool, error) { return true, nil }
sd.readFileString = func(file string) (string, error) { return "not-a-match", nil }

stateDir, err := sd.Compute()

require.NoError(t, err)
require.Equal(t, newTsnetConfigDir, stateDir)
}
Loading