diff --git a/runner/cmd/runner/cmd.go b/runner/cmd/runner/cmd.go index b217f0c54..8428d456e 100644 --- a/runner/cmd/runner/cmd.go +++ b/runner/cmd/runner/cmd.go @@ -13,6 +13,7 @@ var Version string func App() { var paths struct{ tempDir, homeDir, workingDir string } var httpPort int + var sshPort int var logLevel int app := &cli.App{ @@ -57,9 +58,15 @@ func App() { Value: 10999, Destination: &httpPort, }, + &cli.IntFlag{ + Name: "ssh-port", + Usage: "Set the ssh port", + Required: true, + Destination: &sshPort, + }, }, Action: func(c *cli.Context) error { - err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel, Version) + err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, sshPort, logLevel, Version) if err != nil { return cli.Exit(err, 1) } diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 72f7e5f8c..173b12a1d 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -19,7 +19,7 @@ func main() { App() } -func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int, version string) error { +func start(tempDir string, homeDir string, workingDir string, httpPort int, sshPort int, logLevel int, version string) error { if err := os.MkdirAll(tempDir, 0o755); err != nil { return tracerr.Errorf("Failed to create temp directory: %w", err) } @@ -38,7 +38,7 @@ func start(tempDir string, homeDir string, workingDir string, httpPort int, logL log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) - server, err := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort), version) + server, err := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort), sshPort, version) if err != nil { return tracerr.Errorf("Failed to create server: %w", err) } diff --git a/runner/go.mod b/runner/go.mod index caa65de76..2eeee8539 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -84,5 +84,6 @@ require ( require ( github.com/codeclysm/extract/v3 v3.1.1 github.com/gorilla/websocket v1.5.1 + github.com/prometheus/procfs v0.15.1 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/runner/go.sum b/runner/go.sum index cea29e148..c3cdaf43f 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -145,6 +145,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= diff --git a/runner/internal/connections/connections.go b/runner/internal/connections/connections.go new file mode 100644 index 000000000..74214a4d8 --- /dev/null +++ b/runner/internal/connections/connections.go @@ -0,0 +1,130 @@ +package connections + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/dstackai/dstack/runner/internal/log" + "github.com/prometheus/procfs" +) + +const connStateEstablished = 1 + +type connection struct { + fromAddr string + fromPort uint64 +} + +type trackingInfo struct { + firstSeenAt time.Time +} + +type ConnectionTrackerConfig struct { + Port uint64 + MinConnDuration time.Duration + Procfs procfs.FS +} + +// Tracks TCP connections to a specified port. +type ConnectionTracker struct { + cfg ConnectionTrackerConfig + connections map[connection]trackingInfo + lastConnectionAt *time.Time + lastCheckedAt *time.Time + stopChan chan struct{} + mu sync.RWMutex +} + +func NewConnectionTracker(cfg ConnectionTrackerConfig) *ConnectionTracker { + tracker := ConnectionTracker{ + cfg: cfg, + connections: make(map[connection]trackingInfo), + lastConnectionAt: nil, + lastCheckedAt: nil, + stopChan: make(chan struct{}), + mu: sync.RWMutex{}, + } + return &tracker +} + +// Returns the number of seconds since the last connection was closed or +// since tracking started. If tracking hasn't started yet, returns 0. +func (t *ConnectionTracker) GetNoConnectionsSecs() int64 { + t.mu.RLock() + defer t.mu.RUnlock() + if t.lastConnectionAt == nil || t.lastCheckedAt == nil { + return 0 + } + return int64(t.lastCheckedAt.Sub(*t.lastConnectionAt).Seconds()) +} + +func (t *ConnectionTracker) Track(ticker <-chan time.Time) { + for { + select { + case now := <-ticker: + t.updateConnections(now) + case <-t.stopChan: + return + } + } +} + +func (t *ConnectionTracker) Stop() { + t.stopChan <- struct{}{} +} + +func (t *ConnectionTracker) updateConnections(now time.Time) { + currentConnections, err := t.getCurrentConnections() + if err != nil { + log.Error(context.TODO(), "Failed to retrieve connections: %v", err) + return + } + t.mu.Lock() + defer t.mu.Unlock() + // evict closed connections + for conn := range t.connections { + if _, ok := currentConnections[conn]; !ok { + delete(t.connections, conn) + } + } + // add new connections + for conn := range currentConnections { + if _, ok := t.connections[conn]; !ok { + t.connections[conn] = trackingInfo{firstSeenAt: now} + } + } + // update lastConnectionAt + for _, connInfo := range t.connections { + if now.Sub(connInfo.firstSeenAt) > t.cfg.MinConnDuration { + t.lastConnectionAt = &now + break + } + } + if t.lastConnectionAt == nil { // first call to updateConnections + t.lastConnectionAt = &now + } + t.lastCheckedAt = &now +} + +func (t *ConnectionTracker) getCurrentConnections() (map[connection]struct{}, error) { + connections := make(map[connection]struct{}) + netTCP, err := t.cfg.Procfs.NetTCP() + if err != nil { + return nil, fmt.Errorf("Failed to retrieve IPv4 network connections: %w", err) + } + netTCP6, err := t.cfg.Procfs.NetTCP6() + if err != nil { + return nil, fmt.Errorf("Failed to retrieve IPv6 network connections: %w", err) + } + for _, conn := range append(netTCP, netTCP6...) { + if conn.LocalPort == t.cfg.Port && conn.St == connStateEstablished { + connections[connection{ + fromAddr: conn.RemAddr.String(), + fromPort: conn.RemPort, + }] = struct{}{} + } + } + return connections, nil +} diff --git a/runner/internal/connections/connections_test.go b/runner/internal/connections/connections_test.go new file mode 100644 index 000000000..4d46fb543 --- /dev/null +++ b/runner/internal/connections/connections_test.go @@ -0,0 +1,94 @@ +package connections + +import ( + "io/fs" + "os" + "testing" + "time" + + "github.com/prometheus/procfs" + "github.com/stretchr/testify/assert" +) + +func TestConnectionTracker(t *testing.T) { + procfsDir := t.TempDir() + proc, err := procfs.NewFS(procfsDir) + assert.NoError(t, err) + err = os.Mkdir(procfsDir+"/net", os.ModePerm) + assert.NoError(t, err) + tracker := NewConnectionTracker(ConnectionTrackerConfig{ + Port: 4096, + MinConnDuration: 5 * time.Second, + Procfs: proc, + }) + ticker := make(chan time.Time) + // Open sockets on ports 53 and 4096 + established connection to port 53 (irrelevant) + noConnTcp := ` sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode + 0: 3500007F:0035 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0 + 1: 00000000:1000 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0 + 2: 3500007F:0035 0100007F:1234 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0 +` + noConnTcp6 := ` sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode + 0: 00000000000000000000000000000000:0035 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0 + 1: 00000000000000000000000000000000:1000 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0 + 2: 00000000000000000000000000000000:0035 00000000000000000000000001000000:1234 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0 +` + // Established connection to port 4096 (relevant) + connTcp := " 3: 00000000:1000 0100007F:4321 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0" + connTcp6 := " 3: 00000000000000000000000000000000:1000 00000000000000000000000001000000:4321 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0" + + // Tracking did not start yet + // Returns 0 secs + assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs()) + + go tracker.Track(ticker) + defer tracker.Stop() + assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs()) + + // There is a 2-second-old connection + // Returns 2 secs (the connection doesn't count as it's < MinConnDuration) + writeProcfs(t, procfsDir, noConnTcp+connTcp, noConnTcp6) + tick := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ticker <- tick + wait() + tick = tick.Add(2 * time.Second) + ticker <- tick + wait() + assert.Equal(t, int64(2), tracker.GetNoConnectionsSecs()) + + // There is a 6-second-old connection + // Returns 0 secs (the connection is >= MinConnDuration) + tick = tick.Add(4 * time.Second) + ticker <- tick + wait() + assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs()) + + // The connection is closed and there are no connections for 15 secs. + // Returns 15 secs + writeProcfs(t, procfsDir, noConnTcp, noConnTcp6) + tick = tick.Add(15 * time.Second) + ticker <- tick + wait() + assert.Equal(t, int64(15), tracker.GetNoConnectionsSecs()) + + // There is a 7-second-old connection over IPv6 + // Returns 0 secs (the connection is >= MinConnDuration) + writeProcfs(t, procfsDir, noConnTcp, noConnTcp6+connTcp6) + tick = tick.Add(1 * time.Second) + ticker <- tick + tick = tick.Add(7 * time.Second) + ticker <- tick + wait() + assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs()) +} + +func writeProcfs(t *testing.T, procfsDir, tcp, tcp6 string) { + err := os.WriteFile(procfsDir+"/net/tcp", []byte(tcp), fs.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(procfsDir+"/net/tcp6", []byte(tcp6), fs.ModePerm) + assert.NoError(t, err) +} + +func wait() { + time.Sleep(30 * time.Millisecond) +} diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 0f3373601..6b4b6374d 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -18,10 +18,12 @@ import ( "github.com/creack/pty" "github.com/dstackai/dstack/runner/consts" + "github.com/dstackai/dstack/runner/internal/connections" "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/schemas" "github.com/dstackai/dstack/runner/internal/types" + "github.com/prometheus/procfs" ) type RunExecutor struct { @@ -44,10 +46,11 @@ type RunExecutor struct { runnerLogs *appendWriter timestamp *MonotonicTimestamp - killDelay time.Duration + killDelay time.Duration + connectionTracker *connections.ConnectionTracker } -func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExecutor, error) { +func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort int) (*RunExecutor, error) { mu := &sync.RWMutex{} timestamp := NewMonotonicTimestamp() user, err := osuser.Current() @@ -58,6 +61,15 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExec if err != nil { return nil, fmt.Errorf("failed to parse current user uid: %w", err) } + proc, err := procfs.NewDefaultFS() + if err != nil { + return nil, fmt.Errorf("failed to initialize procfs: %w", err) + } + connectionTracker := connections.NewConnectionTracker(connections.ConnectionTrackerConfig{ + Port: uint64(sshPort), + MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server + Procfs: proc, + }) return &RunExecutor{ tempDir: tempDir, homeDir: homeDir, @@ -71,7 +83,8 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExec runnerLogs: newAppendWriter(mu, timestamp), timestamp: timestamp, - killDelay: 10 * time.Second, + killDelay: 10 * time.Second, + connectionTracker: connectionTracker, }, nil } @@ -130,6 +143,10 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { } defer cleanupCredentials() + connectionTrackerTicker := time.NewTicker(2500 * time.Millisecond) + go ex.connectionTracker.Track(connectionTrackerTicker.C) + defer ex.connectionTracker.Stop() + ex.SetJobState(ctx, types.JobStateRunning) timeoutCtx := ctx var cancelTimeout context.CancelFunc diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index b62b3061f..ba5113c4f 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -169,7 +169,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor { _ = os.Mkdir(home, 0o700) repo := filepath.Join(baseDir, "repo") _ = os.Mkdir(repo, 0o700) - ex, _ := NewRunExecutor(temp, home, repo) + ex, _ := NewRunExecutor(temp, home, repo, 10022) ex.SetJob(body) ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run return ex diff --git a/runner/internal/executor/query.go b/runner/internal/executor/query.go index 6e7f2d288..1dff4e330 100644 --- a/runner/internal/executor/query.go +++ b/runner/internal/executor/query.go @@ -10,11 +10,12 @@ func (ex *RunExecutor) GetJobLogsHistory() []schemas.LogEvent { func (ex *RunExecutor) GetHistory(timestamp int64) *schemas.PullResponse { return &schemas.PullResponse{ - JobStates: eventsAfter(ex.jobStateHistory, timestamp), - JobLogs: eventsAfter(ex.jobLogs.history, timestamp), - RunnerLogs: eventsAfter(ex.runnerLogs.history, timestamp), - LastUpdated: ex.timestamp.GetLatest(), - HasMore: ex.state != WaitLogsFinished, + JobStates: eventsAfter(ex.jobStateHistory, timestamp), + JobLogs: eventsAfter(ex.jobLogs.history, timestamp), + RunnerLogs: eventsAfter(ex.runnerLogs.history, timestamp), + LastUpdated: ex.timestamp.GetLatest(), + NoConnectionsSecs: ex.connectionTracker.GetNoConnectionsSecs(), + HasMore: ex.state != WaitLogsFinished, } } diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 80dfccc47..bc1b8cc93 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -34,9 +34,9 @@ type Server struct { version string } -func NewServer(tempDir string, homeDir string, workingDir string, address string, version string) (*Server, error) { +func NewServer(tempDir string, homeDir string, workingDir string, address string, sshPort int, version string) (*Server, error) { r := api.NewRouter() - ex, err := executor.NewRunExecutor(tempDir, homeDir, workingDir) + ex, err := executor.NewRunExecutor(tempDir, homeDir, workingDir, sshPort) if err != nil { return nil, err } diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 2626bbdea..8d370e253 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -27,11 +27,12 @@ type SubmitBody struct { } type PullResponse struct { - JobStates []JobStateEvent `json:"job_states"` - JobLogs []LogEvent `json:"job_logs"` - RunnerLogs []LogEvent `json:"runner_logs"` - LastUpdated int64 `json:"last_updated"` - HasMore bool `json:"has_more"` + JobStates []JobStateEvent `json:"job_states"` + JobLogs []LogEvent `json:"job_logs"` + RunnerLogs []LogEvent `json:"runner_logs"` + LastUpdated int64 `json:"last_updated"` + NoConnectionsSecs int64 `json:"no_connections_secs"` + HasMore bool `json:"has_more"` // todo Result } diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 67940289b..7b6bfb3d7 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -903,7 +903,17 @@ func getSSHShellCommands(openSSHPort int, publicSSHKey string) []string { "rm -rf /run/sshd && mkdir -p /run/sshd && chown root:root /run/sshd", "rm -rf /var/empty && mkdir -p /var/empty && chown root:root /var/empty", // start sshd - fmt.Sprintf("/usr/sbin/sshd -p %d -o PidFile=none -o PasswordAuthentication=no -o AllowTcpForwarding=yes -o PermitUserEnvironment=yes", openSSHPort), + fmt.Sprintf( + "/usr/sbin/sshd"+ + " -p %d"+ + " -o PidFile=none"+ + " -o PasswordAuthentication=no"+ + " -o AllowTcpForwarding=yes"+ + " -o PermitUserEnvironment=yes"+ + " -o ClientAliveInterval=30"+ + " -o ClientAliveCountMax=4", + openSSHPort, + ), // restore ld.so variables `if [ -n "$_LD_LIBRARY_PATH" ]; then export LD_LIBRARY_PATH="$_LD_LIBRARY_PATH"; fi`, `if [ -n "$_LD_PRELOAD" ]; then export LD_PRELOAD="$_LD_PRELOAD"; fi`, diff --git a/runner/internal/shim/runner.go b/runner/internal/shim/runner.go index 56dd2b1a7..e044cfcce 100644 --- a/runner/internal/shim/runner.go +++ b/runner/internal/shim/runner.go @@ -31,6 +31,7 @@ func (c *CLIArgs) getRunnerArgs() []string { "--log-level", strconv.Itoa(c.Runner.LogLevel), "start", "--http-port", strconv.Itoa(c.Runner.HTTPPort), + "--ssh-port", strconv.Itoa(c.Runner.SSHPort), "--temp-dir", consts.RunnerTempDir, "--home-dir", consts.RunnerHomeDir, "--working-dir", consts.RunnerWorkingDir, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 8cf479684..eb6141856 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -371,7 +371,16 @@ def get_docker_commands( "rm -rf /run/sshd && mkdir -p /run/sshd && chown root:root /run/sshd", "rm -rf /var/empty && mkdir -p /var/empty && chown root:root /var/empty", # start sshd - f"/usr/sbin/sshd -p {DSTACK_RUNNER_SSH_PORT} -o PidFile=none -o PasswordAuthentication=no -o AllowTcpForwarding=yes -o PermitUserEnvironment=yes", + ( + "/usr/sbin/sshd" + f" -p {DSTACK_RUNNER_SSH_PORT}" + " -o PidFile=none" + " -o PasswordAuthentication=no" + " -o AllowTcpForwarding=yes" + " -o PermitUserEnvironment=yes" + " -o ClientAliveInterval=30" + " -o ClientAliveCountMax=4" + ), # restore ld.so variables 'if [ -n "$_LD_LIBRARY_PATH" ]; then export LD_LIBRARY_PATH="$_LD_LIBRARY_PATH"; fi', 'if [ -n "$_LD_PRELOAD" ]; then export LD_PRELOAD="$_LD_PRELOAD"; fi', @@ -381,7 +390,16 @@ def get_docker_commands( commands += [ f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {DSTACK_RUNNER_BINARY_PATH} {url}", f"chmod +x {DSTACK_RUNNER_BINARY_PATH}", - f"{DSTACK_RUNNER_BINARY_PATH} --log-level 6 start --http-port {DSTACK_RUNNER_HTTP_PORT} --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", + ( + f"{DSTACK_RUNNER_BINARY_PATH}" + " --log-level 6" + " start" + f" --http-port {DSTACK_RUNNER_HTTP_PORT}" + f" --ssh-port {DSTACK_RUNNER_SSH_PORT}" + " --temp-dir /tmp/runner" + " --home-dir /root" + " --working-dir /workflow" + ), ] return commands diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index f420fd2eb..c8736c7eb 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -30,7 +30,7 @@ from dstack._internal.utils.common import get_or_error from dstack._internal.utils.logging import get_logger -REQUEST_TIMEOUT = 15 +REQUEST_TIMEOUT = 9 logger = get_logger(__name__) diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 044c5d3bc..b66f81fd1 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -25,6 +25,14 @@ def runner_ssh_tunnel( [Callable[Concatenate[Dict[int, int], P], R]], Callable[Concatenate[str, JobProvisioningData, Optional[JobRuntimeData], P], Union[bool, R]], ]: + """ + A decorator that opens an SSH tunnel to the runner. + + NOTE: connections from dstack-server to running jobs are expected to be short. + The runner uses a heuristic to differentiate dstack-server connections from + client connections based on their duration. See `ConnectionTracker` for details. + """ + def decorator( func: Callable[Concatenate[Dict[int, int], P], R], ) -> Callable[