Skip to content

Commit

Permalink
Refactor and update deps
Browse files Browse the repository at this point in the history
  • Loading branch information
pjcdawkins committed Dec 28, 2024
1 parent 95446bf commit 930aca0
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 71 deletions.
4 changes: 2 additions & 2 deletions go-tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.22.9
require (
github.com/platformsh/cli v0.0.0-20241227091635-fea73d95f802
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.31.0
)

require (
Expand All @@ -13,7 +14,6 @@ require (
github.com/kr/pretty v0.3.1 // indirect
github.com/oklog/ulid/v2 v2.1.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
golang.org/x/crypto v0.24.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/sys v0.28.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
12 changes: 6 additions & 6 deletions go-tests/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
Expand Down
120 changes: 74 additions & 46 deletions go-tests/mockssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"sync"
"testing"
"time"
Expand All @@ -16,61 +19,63 @@ import (
//go:embed host_key
var hostKey []byte

type MockSSHServer struct {
t *testing.T
commandHandler CommandHandler
listener net.Listener
port int
hostKey ssh.Signer
certChecker ssh.CertChecker
type Server struct {
// t and hostKey are set in NewServer.
t *testing.T
hostKey ssh.Signer

// RemoteEnv, RemoteDir, CertChecker and CommandHandler are optional configuration.
RemoteEnv []string
RemoteDir string
CertChecker ssh.CertChecker
CommandHandler CommandHandler

// listener and port are set after Start.
listener net.Listener
port int
}

type CommandHandler func(conn ssh.ConnMetadata, command string) (string, string, uint32)
type CommandIO struct {
StdIn io.Reader
StdOut io.Writer
StdErr io.Writer
}

// NewStartedServer creates and starts a local SSH server for a test.
// The server will automatically be stopped when the test completes.
func NewStartedServer(t *testing.T, handler CommandHandler, certChecker ssh.CertChecker) (*MockSSHServer, error) {
type CommandHandler func(conn ssh.ConnMetadata, command string, io CommandIO) uint32

// NewServer creates and starts a local SSH server for a test.
// It must be stopped with the Server.Stop method.
func NewServer(t *testing.T) (*Server, error) {
hk, err := ssh.ParsePrivateKey(hostKey)
if err != nil {
return nil, fmt.Errorf("failed to parse host key: %v", err)
}

if certChecker.IsUserAuthority == nil {
return nil, fmt.Errorf("cert checker must define IsUserAuthority")
}
s := &Server{t: t, hostKey: hk}
s.CommandHandler = s.defaultCommandHandler
s.CertChecker = s.defaultCertChecker()
s.RemoteDir = t.TempDir()

s := &MockSSHServer{
t: t,
commandHandler: handler,
hostKey: hk,
certChecker: certChecker,
}
if err := s.start(); err != nil {
return nil, err
}

t.Cleanup(func() {
if err := s.listener.Close(); err != nil {
t.Fatal(err)
}
})

return s, nil
}

func (s *MockSSHServer) Port() int {
func (s *Server) Port() int {
return s.port
}

func (s *MockSSHServer) HostKeyConfig() string {
func (s *Server) HostKeyConfig() string {
return fmt.Sprintf("[127.0.0.1]:%d %s %s",
s.port,
s.hostKey.PublicKey().Type(),
base64.StdEncoding.EncodeToString(s.hostKey.PublicKey().Marshal()),
)
}

func (s *MockSSHServer) start() error {
func (s *Server) start() error {
t := s.t

config := s.serverConfig()
Expand Down Expand Up @@ -120,14 +125,46 @@ func (s *MockSSHServer) start() error {
return nil
}

func (s *MockSSHServer) serverConfig() *ssh.ServerConfig {
func (s *Server) Stop() error {
if s.listener != nil {
return s.listener.Close()
}
return nil
}

func (s *Server) defaultCommandHandler(_ ssh.ConnMetadata, command string, io CommandIO) uint32 {
c := exec.Command("bash", "-c", command)
c.Stdout = io.StdOut
c.Stderr = io.StdErr
c.Stdin = io.StdIn
c.Dir = s.RemoteDir
c.Env = append(os.Environ(), s.RemoteEnv...)
if err := c.Run(); err != nil {
exitErr := &exec.ExitError{}
if errors.As(err, &exitErr) {
return uint32(exitErr.ExitCode())
}
_, _ = io.StdErr.Write([]byte(fmt.Sprintf("Failed to execute command: %v\n", err)))
return 1
}
return 0
}

func (s *Server) defaultCertChecker() ssh.CertChecker {
return ssh.CertChecker{IsUserAuthority: func(auth ssh.PublicKey) bool {
s.t.Log("No CertChecker defined; rejecting certificate")
return false
}}
}

func (s *Server) serverConfig() *ssh.ServerConfig {
t := s.t
conf := &ssh.ServerConfig{}
conf.AddHostKey(s.hostKey)
conf.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if cert, ok := key.(*ssh.Certificate); ok {
t.Logf("SSH certificate received from %s with key ID %s", conn.RemoteAddr(), cert.KeyId)
return s.certChecker.Authenticate(conn, cert)
return s.CertChecker.Authenticate(conn, cert)
}
return nil, fmt.Errorf("not accepting public key type: %s", key.Type())
}
Expand All @@ -144,7 +181,7 @@ func (s *MockSSHServer) serverConfig() *ssh.ServerConfig {
return conf
}

func (s *MockSSHServer) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChannel) {
func (s *Server) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChannel) {
t := s.t

for newChannel := range chans {
Expand Down Expand Up @@ -180,20 +217,11 @@ func (s *MockSSHServer) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.N
// See https://datatracker.ietf.org/doc/html/rfc4251#section-5
cmd := req.Payload[4:]
t.Logf("Handling command: %s", cmd)
stdOut, stdErr, exitCode := s.commandHandler(conn, string(cmd))
if len(stdOut) > 0 {
_, err = channel.Write([]byte(stdOut))
if err != nil {
t.Errorf("Failed to write to stdout channel: %v", err)
}
}
if len(stdErr) > 0 {
_, err = channel.Stderr().Write([]byte(stdErr))
if err != nil {
t.Errorf("Failed to write to stderr channel: %v", err)
}
}
exitWithStatus <- exitCode
exitWithStatus <- s.CommandHandler(conn, string(cmd), CommandIO{
StdIn: channel,
StdOut: channel,
StdErr: channel.Stderr(),
})
return
default:
_ = req.Reply(false, nil)
Expand Down
36 changes: 19 additions & 17 deletions go-tests/ssh_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tests

import (
"fmt"
"net/http/httptest"
"os/exec"
"strconv"
Expand All @@ -19,22 +18,24 @@ func TestSSH(t *testing.T) {

myUserID := "my-user-id"

sshServer, err := mockssh.NewStartedServer(t, func(conn ssh.ConnMetadata, command string) (string, string, uint32) {
switch command {
case "pwd":
return "/mock/path", "", 0
case "fail-with-code-2":
return "", "Returning exit code 2\n", 2
}
return "", fmt.Sprintf("Unknown command: %s\n", command), 1
}, ssh.CertChecker{IsUserAuthority: func(auth ssh.PublicKey) bool {
// TODO use the auth server's keys
t.Logf("checking if key is user authority: %s", ssh.MarshalAuthorizedKey(auth))
return true
}})
sshServer, err := mockssh.NewServer(t)
if err != nil {
t.Fatal(err)
}
sshServer.CertChecker = ssh.CertChecker{IsUserAuthority: func(auth ssh.PublicKey) bool {
// TODO use the auth server's keys
t.Logf("checking if key is user authority (returning true for anything, for now), key: %s", ssh.MarshalAuthorizedKey(auth))
return true
}}
sshServer.RemoteEnv = []string{
// TODO use this
"PLATFORM_RELATIONSHIPS=e30K",
}
t.Cleanup(func() {
if err := sshServer.Stop(); err != nil {
t.Error(err)
}
})

projectID := "aiyaikii1uere"

Expand Down Expand Up @@ -76,12 +77,13 @@ func TestSSH(t *testing.T) {
}

f.Run("cc")
assert.Equal(t, "/mock/path", f.Run("ssh", "-p", projectID, "-e", ".", "pwd"))
assert.Equal(t, sshServer.RemoteDir+"\n", f.Run("ssh", "-p", projectID, "-e", ".", "pwd"))

_, stdErr, _ := f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "2", "pwd")
_, stdErr, err := f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "2", "pwd")
assert.Error(t, err)
assert.Contains(t, stdErr, "Available instances: 0, 1")

_, _, err = f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "1", "fail-with-code-2")
_, _, err = f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "1", "exit 2")
var exitErr *exec.ExitError
assert.ErrorAs(t, err, &exitErr)
assert.Equal(t, 2, exitErr.ExitCode())
Expand Down

0 comments on commit 930aca0

Please sign in to comment.