Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track SSH connections in dstack-runner #2287

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion runner/cmd/runner/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ var Version string
func App() {
var paths struct{ tempDir, homeDir, workingDir string }
var httpPort int
var sshPort int
var logLevel int

app := &cli.App{
Expand Down Expand Up @@ -57,9 +58,15 @@ func App() {
Value: 10999,
Destination: &httpPort,
},
&cli.IntFlag{
Name: "ssh-port",
Usage: "Set the ssh port",
Required: true,
Destination: &sshPort,
},
},
Action: func(c *cli.Context) error {
err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel, Version)
err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, sshPort, logLevel, Version)
if err != nil {
return cli.Exit(err, 1)
}
Expand Down
4 changes: 2 additions & 2 deletions runner/cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func main() {
App()
}

func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int, version string) error {
func start(tempDir string, homeDir string, workingDir string, httpPort int, sshPort int, logLevel int, version string) error {
if err := os.MkdirAll(tempDir, 0o755); err != nil {
return tracerr.Errorf("Failed to create temp directory: %w", err)
}
Expand All @@ -38,7 +38,7 @@ func start(tempDir string, homeDir string, workingDir string, httpPort int, logL
log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile))
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel))

server, err := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort), version)
server, err := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort), sshPort, version)
if err != nil {
return tracerr.Errorf("Failed to create server: %w", err)
}
Expand Down
1 change: 1 addition & 0 deletions runner/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,6 @@ require (
require (
github.com/codeclysm/extract/v3 v3.1.1
github.com/gorilla/websocket v1.5.1
github.com/prometheus/procfs v0.15.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions runner/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
Expand Down
130 changes: 130 additions & 0 deletions runner/internal/connections/connections.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package connections

import (
"context"
"fmt"
"sync"
"time"

"github.com/dstackai/dstack/runner/internal/log"
"github.com/prometheus/procfs"
)

// connStateEstablished is the socket state value that getCurrentConnections
// requires of a procfs TCP entry (its St field) for the entry to count as an
// open connection. It corresponds to the ESTABLISHED state ("01" in the st
// column of /proc/net/tcp).
const connStateEstablished = 1

// connection identifies a tracked TCP connection by its remote endpoint.
// It is comparable and is used as a map key by ConnectionTracker.
type connection struct {
	// fromAddr is the remote address, formatted with RemAddr.String() of the
	// procfs entry.
	fromAddr string
	// fromPort is the remote port (RemPort of the procfs entry).
	fromPort uint64
}

// trackingInfo is the bookkeeping ConnectionTracker stores per connection.
type trackingInfo struct {
	// firstSeenAt is the tick time at which updateConnections first observed
	// the connection. It is the time of observation, not the time the
	// connection was actually opened.
	firstSeenAt time.Time
}

// ConnectionTrackerConfig holds the settings of a ConnectionTracker.
type ConnectionTrackerConfig struct {
	// Port is the local TCP port whose established connections are tracked.
	Port uint64
	// MinConnDuration is the observation time a connection must exceed before
	// it counts as activity: a connection counts only once it has been seen
	// for strictly longer than this duration.
	MinConnDuration time.Duration
	// Procfs is the procfs handle from which the IPv4 and IPv6 TCP tables
	// are read.
	Procfs procfs.FS
}

// ConnectionTracker tracks TCP connections to a specified port.
//
// Each tick received by Track triggers a poll of procfs; the results are
// exposed through GetNoConnectionsSecs.
type ConnectionTracker struct {
	// cfg is the configuration the tracker was created with.
	cfg ConnectionTrackerConfig
	// connections holds the connections seen as established in the latest
	// successful poll, with the tick time each was first observed.
	connections map[connection]trackingInfo
	// lastConnectionAt is the time of the latest poll that found a connection
	// observed for longer than cfg.MinConnDuration, or the time of the first
	// poll if no such connection has been found yet. Nil before the first
	// successful poll.
	lastConnectionAt *time.Time
	// lastCheckedAt is the time of the latest successful poll. Nil before the
	// first successful poll.
	lastCheckedAt *time.Time
	// stopChan carries the stop signal from Stop to the Track loop.
	stopChan chan struct{}
	// mu is held by updateConnections (write lock) while it modifies
	// connections, lastConnectionAt and lastCheckedAt, and by
	// GetNoConnectionsSecs (read lock) while it reads the timestamps.
	mu sync.RWMutex
}

// NewConnectionTracker builds a tracker for cfg. The returned tracker is
// idle: it has no recorded connections and no timestamps until Track starts
// receiving ticks.
func NewConnectionTracker(cfg ConnectionTrackerConfig) *ConnectionTracker {
	// lastConnectionAt, lastCheckedAt and mu are intentionally left at their
	// zero values (nil, nil and an unlocked mutex).
	return &ConnectionTracker{
		cfg:         cfg,
		connections: map[connection]trackingInfo{},
		stopChan:    make(chan struct{}),
	}
}

// GetNoConnectionsSecs reports, in whole seconds (fraction truncated), the
// time between the last poll that counted as connection activity and the most
// recent successful poll. Before any poll has completed both timestamps are
// unset and the result is 0.
func (t *ConnectionTracker) GetNoConnectionsSecs() int64 {
	t.mu.RLock()
	defer t.mu.RUnlock()

	lastConn, lastCheck := t.lastConnectionAt, t.lastCheckedAt
	if lastConn != nil && lastCheck != nil {
		idle := lastCheck.Sub(*lastConn)
		return int64(idle.Seconds())
	}
	return 0
}

// Track runs the polling loop: every time value received from ticker is
// passed to updateConnections as the current time. The loop blocks the
// calling goroutine and returns only after a stop signal sent by Stop has
// been received.
func (t *ConnectionTracker) Track(ticker <-chan time.Time) {
	for {
		select {
		case <-t.stopChan:
			return
		case tickTime := <-ticker:
			t.updateConnections(tickTime)
		}
	}
}

// Stop tells the Track loop to return by sending a value on stopChan.
//
// The send blocks until the loop receives it, so Stop does not return while
// no Track loop is running on this tracker.
func (t *ConnectionTracker) Stop() {
	stopSignal := struct{}{}
	t.stopChan <- stopSignal
}

// updateConnections polls the current connections and refreshes the tracker
// state, using now as the time of the poll.
//
// If the poll fails, the error is logged and the state, including
// lastCheckedAt, is left untouched.
func (t *ConnectionTracker) updateConnections(now time.Time) {
	active, err := t.getCurrentConnections()
	if err != nil {
		log.Error(context.TODO(), "Failed to retrieve connections: %v", err)
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	// Forget connections that are no longer established.
	for known := range t.connections {
		if _, stillOpen := active[known]; stillOpen {
			continue
		}
		delete(t.connections, known)
	}

	// Start tracking connections that are seen for the first time.
	for seen := range active {
		if _, tracked := t.connections[seen]; tracked {
			continue
		}
		t.connections[seen] = trackingInfo{firstSeenAt: now}
	}

	// A connection counts as activity only once it has been observed for
	// strictly longer than MinConnDuration.
	hasQualifying := false
	for _, info := range t.connections {
		if now.Sub(info.firstSeenAt) > t.cfg.MinConnDuration {
			hasQualifying = true
			break
		}
	}

	// On the very first poll lastConnectionAt is still unset; it is then
	// initialised to the poll time even without a qualifying connection.
	if hasQualifying || t.lastConnectionAt == nil {
		t.lastConnectionAt = &now
	}
	t.lastCheckedAt = &now
}

// getCurrentConnections reads the IPv4 and IPv6 TCP tables from procfs and
// returns the set of remote endpoints that have an established connection to
// the configured local port. An error reading either table aborts the call
// and is returned wrapped.
func (t *ConnectionTracker) getCurrentConnections() (map[connection]struct{}, error) {
	ipv4Entries, err := t.cfg.Procfs.NetTCP()
	if err != nil {
		return nil, fmt.Errorf("Failed to retrieve IPv4 network connections: %w", err)
	}
	ipv6Entries, err := t.cfg.Procfs.NetTCP6()
	if err != nil {
		return nil, fmt.Errorf("Failed to retrieve IPv6 network connections: %w", err)
	}

	established := make(map[connection]struct{})
	allEntries := append(ipv4Entries, ipv6Entries...)
	for _, entry := range allEntries {
		// Skip sockets on other ports and sockets that are not established.
		if entry.LocalPort != t.cfg.Port || entry.St != connStateEstablished {
			continue
		}
		remote := connection{
			fromAddr: entry.RemAddr.String(),
			fromPort: entry.RemPort,
		}
		established[remote] = struct{}{}
	}
	return established, nil
}
94 changes: 94 additions & 0 deletions runner/internal/connections/connections_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package connections

import (
"io/fs"
"os"
"testing"
"time"

"github.com/prometheus/procfs"
"github.com/stretchr/testify/assert"
)

// TestConnectionTracker drives a ConnectionTracker with hand-written procfs
// TCP tables and manually sent ticks, and checks GetNoConnectionsSecs after
// each step. The tracked port is 4096 with a MinConnDuration of 5 seconds.
func TestConnectionTracker(t *testing.T) {
	// Fake procfs root; writeProcfs places the tables in <procfsDir>/net/tcp
	// and <procfsDir>/net/tcp6.
	procfsDir := t.TempDir()
	proc, err := procfs.NewFS(procfsDir)
	assert.NoError(t, err)
	err = os.Mkdir(procfsDir+"/net", os.ModePerm)
	assert.NoError(t, err)
	tracker := NewConnectionTracker(ConnectionTrackerConfig{
		Port: 4096,
		MinConnDuration: 5 * time.Second,
		Procfs: proc,
	})
	// Unbuffered on purpose: a send completes only when the Track loop has
	// received the tick.
	ticker := make(chan time.Time)
	// Open sockets on ports 53 and 4096 + established connection to port 53 (irrelevant)
	// Ports are hexadecimal in these tables: 0035 is 53 and 1000 is 4096.
	// In the st column, 0A marks the open sockets and 01 an established connection.
	noConnTcp := ` sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 3500007F:0035 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0
1: 00000000:1000 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0
2: 3500007F:0035 0100007F:1234 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0
`
	// Same set of sockets for the IPv6 table.
	noConnTcp6 := ` sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode
0: 00000000000000000000000000000000:0035 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0
1: 00000000000000000000000000000000:1000 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0
2: 00000000000000000000000000000000:0035 00000000000000000000000001000000:1234 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0
`
	// Established connection to port 4096 (relevant)
	// These single rows are appended to the tables above to simulate an open connection.
	connTcp := " 3: 00000000:1000 0100007F:4321 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0"
	connTcp6 := " 3: 00000000000000000000000000000000:1000 00000000000000000000000001000000:4321 01 00000000:00000000 00:00000000 00000000 0 0 12345 1 0000000000000000 100 0 0 10 0"

	// Tracking did not start yet
	// Returns 0 secs
	assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs())

	// The loop is running but no tick has been delivered, so still 0 secs.
	go tracker.Track(ticker)
	defer tracker.Stop()
	assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs())

	// There is a 2-second-old connection
	// Returns 2 secs (the connection doesn't count as it's < MinConnDuration)
	// The first tick sets the baseline timestamp; the second one is 2 seconds later.
	writeProcfs(t, procfsDir, noConnTcp+connTcp, noConnTcp6)
	tick := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	ticker <- tick
	// wait gives the Track goroutine time to finish processing the tick.
	wait()
	tick = tick.Add(2 * time.Second)
	ticker <- tick
	wait()
	assert.Equal(t, int64(2), tracker.GetNoConnectionsSecs())

	// There is a 6-second-old connection
	// Returns 0 secs (the connection is >= MinConnDuration)
	tick = tick.Add(4 * time.Second)
	ticker <- tick
	wait()
	assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs())

	// The connection is closed and there are no connections for 15 secs.
	// Returns 15 secs
	writeProcfs(t, procfsDir, noConnTcp, noConnTcp6)
	tick = tick.Add(15 * time.Second)
	ticker <- tick
	wait()
	assert.Equal(t, int64(15), tracker.GetNoConnectionsSecs())

	// There is a 7-second-old connection over IPv6
	// Returns 0 secs (the connection is >= MinConnDuration)
	writeProcfs(t, procfsDir, noConnTcp, noConnTcp6+connTcp6)
	tick = tick.Add(1 * time.Second)
	// No wait is needed between these two ticks: the second send cannot
	// complete until the Track loop has finished handling the first tick and
	// is receiving again.
	ticker <- tick
	tick = tick.Add(7 * time.Second)
	ticker <- tick
	wait()
	assert.Equal(t, int64(0), tracker.GetNoConnectionsSecs())
}

// writeProcfs replaces the fake TCP tables under procfsDir: tcp is written to
// net/tcp and tcp6 to net/tcp6. The net directory must already exist. Write
// errors are reported as test failures on t; both writes are always attempted.
func writeProcfs(t *testing.T, procfsDir, tcp, tcp6 string) {
	// Mark as a helper so a failed write is attributed to the calling line of
	// the test rather than to this function.
	t.Helper()
	err := os.WriteFile(procfsDir+"/net/tcp", []byte(tcp), fs.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(procfsDir+"/net/tcp6", []byte(tcp6), fs.ModePerm)
	assert.NoError(t, err)
}

// wait pauses the calling goroutine for 30 milliseconds, giving the tracker
// goroutine time to finish processing a tick before the test inspects it.
func wait() {
	const settleTime = 30 * time.Millisecond
	time.Sleep(settleTime)
}
23 changes: 20 additions & 3 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ import (

"github.com/creack/pty"
"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/internal/connections"
"github.com/dstackai/dstack/runner/internal/gerrors"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/schemas"
"github.com/dstackai/dstack/runner/internal/types"
"github.com/prometheus/procfs"
)

type RunExecutor struct {
Expand All @@ -44,10 +46,11 @@ type RunExecutor struct {
runnerLogs *appendWriter
timestamp *MonotonicTimestamp

killDelay time.Duration
killDelay time.Duration
connectionTracker *connections.ConnectionTracker
}

func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExecutor, error) {
func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort int) (*RunExecutor, error) {
mu := &sync.RWMutex{}
timestamp := NewMonotonicTimestamp()
user, err := osuser.Current()
Expand All @@ -58,6 +61,15 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExec
if err != nil {
return nil, fmt.Errorf("failed to parse current user uid: %w", err)
}
proc, err := procfs.NewDefaultFS()
if err != nil {
return nil, fmt.Errorf("failed to initialize procfs: %w", err)
}
connectionTracker := connections.NewConnectionTracker(connections.ConnectionTrackerConfig{
Port: uint64(sshPort),
MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server
Procfs: proc,
})
return &RunExecutor{
tempDir: tempDir,
homeDir: homeDir,
Expand All @@ -71,7 +83,8 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExec
runnerLogs: newAppendWriter(mu, timestamp),
timestamp: timestamp,

killDelay: 10 * time.Second,
killDelay: 10 * time.Second,
connectionTracker: connectionTracker,
}, nil
}

Expand Down Expand Up @@ -130,6 +143,10 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
}
defer cleanupCredentials()

connectionTrackerTicker := time.NewTicker(2500 * time.Millisecond)
go ex.connectionTracker.Track(connectionTrackerTicker.C)
defer ex.connectionTracker.Stop()

ex.SetJobState(ctx, types.JobStateRunning)
timeoutCtx := ctx
var cancelTimeout context.CancelFunc
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
_ = os.Mkdir(home, 0o700)
repo := filepath.Join(baseDir, "repo")
_ = os.Mkdir(repo, 0o700)
ex, _ := NewRunExecutor(temp, home, repo)
ex, _ := NewRunExecutor(temp, home, repo, 10022)
ex.SetJob(body)
ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run
return ex
Expand Down
11 changes: 6 additions & 5 deletions runner/internal/executor/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ func (ex *RunExecutor) GetJobLogsHistory() []schemas.LogEvent {

func (ex *RunExecutor) GetHistory(timestamp int64) *schemas.PullResponse {
return &schemas.PullResponse{
JobStates: eventsAfter(ex.jobStateHistory, timestamp),
JobLogs: eventsAfter(ex.jobLogs.history, timestamp),
RunnerLogs: eventsAfter(ex.runnerLogs.history, timestamp),
LastUpdated: ex.timestamp.GetLatest(),
HasMore: ex.state != WaitLogsFinished,
JobStates: eventsAfter(ex.jobStateHistory, timestamp),
JobLogs: eventsAfter(ex.jobLogs.history, timestamp),
RunnerLogs: eventsAfter(ex.runnerLogs.history, timestamp),
LastUpdated: ex.timestamp.GetLatest(),
NoConnectionsSecs: ex.connectionTracker.GetNoConnectionsSecs(),
HasMore: ex.state != WaitLogsFinished,
}
}

Expand Down
4 changes: 2 additions & 2 deletions runner/internal/runner/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ type Server struct {
version string
}

func NewServer(tempDir string, homeDir string, workingDir string, address string, version string) (*Server, error) {
func NewServer(tempDir string, homeDir string, workingDir string, address string, sshPort int, version string) (*Server, error) {
r := api.NewRouter()
ex, err := executor.NewRunExecutor(tempDir, homeDir, workingDir)
ex, err := executor.NewRunExecutor(tempDir, homeDir, workingDir, sshPort)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading