From 930aca0a5dbb9bcfae0ce7175d6a391d099aba2a Mon Sep 17 00:00:00 2001 From: Patrick Dawkins Date: Sat, 28 Dec 2024 19:31:10 +0000 Subject: [PATCH] Refactor and update deps --- go-tests/go.mod | 4 +- go-tests/go.sum | 12 ++-- go-tests/mockssh/server.go | 120 +++++++++++++++++++++++-------------- go-tests/ssh_test.go | 36 +++++------ 4 files changed, 101 insertions(+), 71 deletions(-) diff --git a/go-tests/go.mod b/go-tests/go.mod index c0c1492f8..5dbcbf328 100644 --- a/go-tests/go.mod +++ b/go-tests/go.mod @@ -5,6 +5,7 @@ go 1.22.9 require ( github.com/platformsh/cli v0.0.0-20241227091635-fea73d95f802 github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.31.0 ) require ( @@ -13,7 +14,6 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/oklog/ulid/v2 v2.1.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - golang.org/x/crypto v0.24.0 // indirect - golang.org/x/sys v0.21.0 // indirect + golang.org/x/sys v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go-tests/go.sum b/go-tests/go.sum index cc868ae66..caf6f2b97 100644 --- a/go-tests/go.sum +++ b/go-tests/go.sum @@ -19,12 +19,12 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/go-tests/mockssh/server.go b/go-tests/mockssh/server.go index e32ce87a4..558988e4e 100644 --- a/go-tests/mockssh/server.go +++ b/go-tests/mockssh/server.go @@ -5,7 +5,10 @@ import ( "encoding/base64" "errors" "fmt" + "io" "net" + "os" + "os/exec" "sync" "testing" "time" @@ -16,53 +19,55 @@ import ( //go:embed host_key var hostKey []byte -type MockSSHServer struct { - t *testing.T - commandHandler CommandHandler - listener net.Listener - port int - hostKey ssh.Signer - certChecker ssh.CertChecker +type Server struct { + // t and hostKey are set in NewServer. + t *testing.T + hostKey ssh.Signer + + // RemoteEnv, RemoteDir, CertChecker and CommandHandler are optional configuration. + RemoteEnv []string + RemoteDir string + CertChecker ssh.CertChecker + CommandHandler CommandHandler + + // listener and port are set after Start. + listener net.Listener + port int } -type CommandHandler func(conn ssh.ConnMetadata, command string) (string, string, uint32) +type CommandIO struct { + StdIn io.Reader + StdOut io.Writer + StdErr io.Writer +} -// NewStartedServer creates and starts a local SSH server for a test. -// The server will automatically be stopped when the test completes. -func NewStartedServer(t *testing.T, handler CommandHandler, certChecker ssh.CertChecker) (*MockSSHServer, error) { +type CommandHandler func(conn ssh.ConnMetadata, command string, io CommandIO) uint32 + +// NewServer creates and starts a local SSH server for a test. +// It must be stopped with the Server.Stop method. +func NewServer(t *testing.T) (*Server, error) { hk, err := ssh.ParsePrivateKey(hostKey) if err != nil { return nil, fmt.Errorf("failed to parse host key: %v", err) } - if certChecker.IsUserAuthority == nil { - return nil, fmt.Errorf("cert checker must define IsUserAuthority") - } + s := &Server{t: t, hostKey: hk} + s.CommandHandler = s.defaultCommandHandler + s.CertChecker = s.defaultCertChecker() + s.RemoteDir = t.TempDir() - s := &MockSSHServer{ - t: t, - commandHandler: handler, - hostKey: hk, - certChecker: certChecker, - } if err := s.start(); err != nil { return nil, err } - t.Cleanup(func() { - if err := s.listener.Close(); err != nil { - t.Fatal(err) - } - }) - return s, nil } -func (s *MockSSHServer) Port() int { +func (s *Server) Port() int { return s.port } -func (s *MockSSHServer) HostKeyConfig() string { +func (s *Server) HostKeyConfig() string { return fmt.Sprintf("[127.0.0.1]:%d %s %s", s.port, s.hostKey.PublicKey().Type(), @@ -70,7 +75,7 @@ func (s *MockSSHServer) HostKeyConfig() string { ) } -func (s *MockSSHServer) start() error { +func (s *Server) start() error { t := s.t config := s.serverConfig() @@ -120,14 +125,46 @@ func (s *MockSSHServer) start() error { return nil } -func (s *MockSSHServer) serverConfig() *ssh.ServerConfig { +func (s *Server) Stop() error { + if s.listener != nil { + return s.listener.Close() + } + return nil +} + +func (s *Server) defaultCommandHandler(_ ssh.ConnMetadata, command string, io CommandIO) uint32 { + c := exec.Command("bash", "-c", command) + c.Stdout = io.StdOut + c.Stderr = io.StdErr + c.Stdin = io.StdIn + c.Dir = s.RemoteDir + c.Env = append(os.Environ(), s.RemoteEnv...) + if err := c.Run(); err != nil { + exitErr := &exec.ExitError{} + if errors.As(err, &exitErr) { + return uint32(exitErr.ExitCode()) + } + _, _ = io.StdErr.Write([]byte(fmt.Sprintf("Failed to execute command: %v\n", err))) + return 1 + } + return 0 +} + +func (s *Server) defaultCertChecker() ssh.CertChecker { + return ssh.CertChecker{IsUserAuthority: func(auth ssh.PublicKey) bool { + s.t.Log("No CertChecker defined; rejecting certificate") + return false + }} +} + +func (s *Server) serverConfig() *ssh.ServerConfig { t := s.t conf := &ssh.ServerConfig{} conf.AddHostKey(s.hostKey) conf.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if cert, ok := key.(*ssh.Certificate); ok { t.Logf("SSH certificate received from %s with key ID %s", conn.RemoteAddr(), cert.KeyId) - return s.certChecker.Authenticate(conn, cert) + return s.CertChecker.Authenticate(conn, cert) } return nil, fmt.Errorf("not accepting public key type: %s", key.Type()) } @@ -144,7 +181,7 @@ func (s *MockSSHServer) serverConfig() *ssh.ServerConfig { return conf } -func (s *MockSSHServer) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChannel) { +func (s *Server) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChannel) { t := s.t for newChannel := range chans { @@ -180,20 +217,11 @@ func (s *MockSSHServer) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.N // See https://datatracker.ietf.org/doc/html/rfc4251#section-5 cmd := req.Payload[4:] t.Logf("Handling command: %s", cmd) - stdOut, stdErr, exitCode := s.commandHandler(conn, string(cmd)) - if len(stdOut) > 0 { - _, err = channel.Write([]byte(stdOut)) - if err != nil { - t.Errorf("Failed to write to stdout channel: %v", err) - } - } - if len(stdErr) > 0 { - _, err = channel.Stderr().Write([]byte(stdErr)) - if err != nil { - t.Errorf("Failed to write to stderr channel: %v", err) - } - } - exitWithStatus <- exitCode + exitWithStatus <- s.CommandHandler(conn, string(cmd), CommandIO{ + StdIn: channel, + StdOut: channel, + StdErr: channel.Stderr(), + }) return default: _ = req.Reply(false, nil) diff --git a/go-tests/ssh_test.go b/go-tests/ssh_test.go index 94c158c40..2ac89d6e2 100644 --- a/go-tests/ssh_test.go +++ b/go-tests/ssh_test.go @@ -1,7 +1,6 @@ package tests import ( - "fmt" "net/http/httptest" "os/exec" "strconv" @@ -19,22 +18,24 @@ func TestSSH(t *testing.T) { myUserID := "my-user-id" - sshServer, err := mockssh.NewStartedServer(t, func(conn ssh.ConnMetadata, command string) (string, string, uint32) { - switch command { - case "pwd": - return "/mock/path", "", 0 - case "fail-with-code-2": - return "", "Returning exit code 2\n", 2 - } - return "", fmt.Sprintf("Unknown command: %s\n", command), 1 - }, ssh.CertChecker{IsUserAuthority: func(auth ssh.PublicKey) bool { - // TODO use the auth server's keys - t.Logf("checking if key is user authority: %s", ssh.MarshalAuthorizedKey(auth)) - return true - }}) + sshServer, err := mockssh.NewServer(t) if err != nil { t.Fatal(err) } + sshServer.CertChecker = ssh.CertChecker{IsUserAuthority: func(auth ssh.PublicKey) bool { + // TODO use the auth server's keys + t.Logf("checking if key is user authority (returning true for anything, for now), key: %s", ssh.MarshalAuthorizedKey(auth)) + return true + }} + sshServer.RemoteEnv = []string{ + // TODO use this + "PLATFORM_RELATIONSHIPS=e30K", + } + t.Cleanup(func() { + if err := sshServer.Stop(); err != nil { + t.Error(err) + } + }) projectID := "aiyaikii1uere" @@ -76,12 +77,13 @@ func TestSSH(t *testing.T) { } f.Run("cc") - assert.Equal(t, "/mock/path", f.Run("ssh", "-p", projectID, "-e", ".", "pwd")) + assert.Equal(t, sshServer.RemoteDir+"\n", f.Run("ssh", "-p", projectID, "-e", ".", "pwd")) - _, stdErr, _ := f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "2", "pwd") + _, stdErr, err := f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "2", "pwd") + assert.Error(t, err) assert.Contains(t, stdErr, "Available instances: 0, 1") - _, _, err = f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "1", "fail-with-code-2") + _, _, err = f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "1", "exit 2") var exitErr *exec.ExitError assert.ErrorAs(t, err, &exitErr) assert.Equal(t, 2, exitErr.ExitCode())