From 817ff52180ecff27e21e491f9e7df6179ed65af5 Mon Sep 17 00:00:00 2001 From: Disco Date: Sat, 11 Dec 2021 16:06:59 +0000 Subject: [PATCH 1/3] Moved job methods to more appropriate location. --- config.go | 24 ------------------------ job.go | 30 +++++++++++++++++++++--------- session.go | 12 ++++++++++++ 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/config.go b/config.go index dc6d36f..677927e 100644 --- a/config.go +++ b/config.go @@ -26,13 +26,6 @@ type Config struct { BastionHostSSHConfig *ssh.ClientConfig } -// Job is a single remote task config. For script files, use Job.SetLocalScript(). -type Job struct { - Command string - Script []byte - ScriptArgs string -} - // NewConfig initialises a new massh.Config. func NewConfig() *Config { c := &Config{ @@ -173,25 +166,8 @@ func (c *Config) SetPrivateKeyAuth(PrivateKeyFile string, PrivateKeyPassphrase s return nil } -// SetCommand sets the Command value in Job. This is the Command executed over SSH to all hosts. -func (j *Job) SetCommand(command string) { - j.Command = command -} - // SetPasswordAuth sets ssh password from provided byte slice (read from terminal) func (c *Config) SetPasswordAuth(username string, password string) { c.SSHConfig.User = username c.SSHConfig.Auth = append(c.SSHConfig.Auth, ssh.Password(password)) } - -// SetLocalScript reads a script file contents into the Job config. -func (j *Job) SetLocalScript(filename string, args string) error { - var err error - j.Script, err = ioutil.ReadFile(filename) - if err != nil { - return fmt.Errorf("failed to read script file") - } - j.ScriptArgs = args - - return nil -} diff --git a/job.go b/job.go index f97c3e5..ae00ff3 100644 --- a/job.go +++ b/job.go @@ -1,18 +1,30 @@ package massh import ( - "bytes" "fmt" - "golang.org/x/crypto/ssh" + "io/ioutil" ) -// getJob determines the type of job and returns the command string -func getJob(s *ssh.Session, j *Job) string { - // Set up remote script - if j.Script != nil { - s.Stdin = bytes.NewReader(j.Script) - return fmt.Sprintf("cat > outfile.sh && chmod +x ./outfile.sh && ./outfile.sh %s && rm ./outfile.sh", j.ScriptArgs) +// Job is a single remote task config. For script files, use Job.SetLocalScript(). +type Job struct { + Command string + Script []byte + ScriptArgs string +} + +// SetCommand sets the Command value in Job. This is the Command executed over SSH to all hosts. +func (j *Job) SetCommand(command string) { + j.Command = command +} + +// SetLocalScript reads a script file contents into the Job config. +func (j *Job) SetLocalScript(filename string, args string) error { + var err error + j.Script, err = ioutil.ReadFile(filename) + if err != nil { + return fmt.Errorf("failed to read script file") } + j.ScriptArgs = args - return j.Command + return nil } diff --git a/session.go b/session.go index 3d78f0a..8f3c650 100644 --- a/session.go +++ b/session.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "fmt" + "golang.org/x/crypto/ssh" "io" "sync" ) @@ -33,6 +34,17 @@ type Result struct { DoneChannel chan struct{} } +// getJob determines the type of job and returns the command string +func getJob(s *ssh.Session, j *Job) string { + // Set up remote script + if j.Script != nil { + s.Stdin = bytes.NewReader(j.Script) + return fmt.Sprintf("cat > outfile.sh && chmod +x ./outfile.sh && ./outfile.sh %s && rm ./outfile.sh", j.ScriptArgs) + } + + return j.Command +} + // sshCommand runs an SSH task and returns Result only when the command has finished executing. func sshCommand(host string, config *Config) Result { var r Result From 2a3964c5b3c0a12edf336a32a76a04cfe3f19be4 Mon Sep 17 00:00:00 2001 From: Disco Date: Sun, 12 Dec 2021 21:52:58 +0000 Subject: [PATCH 2/3] Updated examples --- _examples/bulk_return/main.go | 41 +++++++++ _examples/example_bulk_return/command.go | 76 ---------------- _examples/example_bulk_return/main.go | 58 ------------ .../main.go | 13 ++- _examples/python_script/main.go | 46 ++++++++++ _examples/python_script/script.py | 3 + _examples/shell_script/main.go | 46 ++++++++++ _examples/shell_script/script.sh | 2 + .../{example_streaming => streaming}/main.go | 12 +-- _examples/streaming_script/main.go | 83 +++++++++++++++++ _examples/streaming_script/script.sh | 2 + _examples/streaming_script_jobstack/main.go | 91 +++++++++++++++++++ .../streaming_script_jobstack/script1.sh | 2 + .../streaming_script_jobstack/script2.sh | 2 + .../streaming_script_jobstack/script3.sh | 2 + 15 files changed, 331 insertions(+), 148 deletions(-) create mode 100644 _examples/bulk_return/main.go delete mode 100644 _examples/example_bulk_return/command.go delete mode 100644 _examples/example_bulk_return/main.go rename _examples/{example_jobstack_streaming => jobstack_streaming}/main.go (87%) create mode 100644 _examples/python_script/main.go create mode 100644 _examples/python_script/script.py create mode 100644 _examples/shell_script/main.go create mode 100644 _examples/shell_script/script.sh rename _examples/{example_streaming => streaming}/main.go (91%) create mode 100644 _examples/streaming_script/main.go create mode 100644 _examples/streaming_script/script.sh create mode 100644 _examples/streaming_script_jobstack/main.go create mode 100644 _examples/streaming_script_jobstack/script1.sh create mode 100644 _examples/streaming_script_jobstack/script2.sh create mode 100644 _examples/streaming_script_jobstack/script3.sh diff --git a/_examples/bulk_return/main.go b/_examples/bulk_return/main.go new file mode 100644 index 0000000..b00cac1 --- /dev/null +++ b/_examples/bulk_return/main.go @@ -0,0 +1,41 @@ +package main + +import ( + "fmt" + "github.com/discoriver/massh" + "golang.org/x/crypto/ssh" + "time" +) + +func main() { + j := &massh.Job{ + Command: "echo \"Hello, World\"", + } + + sshc := &ssh.ClientConfig{ + // Fake credentials + User: "u01", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: time.Duration(2) * time.Second, + } + + cfg := massh.NewConfig() + cfg.SSHConfig = sshc + cfg.Job = j + cfg.WorkerPool = 10 + cfg.SetHosts([]string{"192.168.1.118", "192.168.1.123"}) + + res, err := cfg.Run() + if err != nil { + panic(err) + } + + for i := range res { + if res[i].Error != nil { + fmt.Printf("%s: %s\n", res[i].Host, res[i].Error) + } else { + fmt.Printf("%s: %s", res[i].Host, res[i].Output) + } + } +} diff --git a/_examples/example_bulk_return/command.go b/_examples/example_bulk_return/command.go deleted file mode 100644 index 1a43621..0000000 --- a/_examples/example_bulk_return/command.go +++ /dev/null @@ -1,76 +0,0 @@ -package main - -import ( - "bufio" - "flag" - "fmt" - "io" - "os" - "strings" -) - -var command cmdEnv - -type cmdEnv struct { - Hosts []string - Script string - ScriptArgs string - WorkerPool int - User string - Timeout int - PrivateKey string - Insecure bool - Command string -} - -func parseCommands() { - flag.StringVar(&command.Script, "s", "", "Path to script file. Overrides -c switch.") - flag.StringVar(&command.ScriptArgs, "a", "", "Arguments for script") - flag.IntVar(&command.WorkerPool, "w", 5, "Specify amount of concurrent workers.") - flag.StringVar(&command.User, "u", "", "Specify user for ssh.") - flag.IntVar(&command.Timeout, "t", 10, "Timeout for ssh.") - flag.StringVar(&command.PrivateKey, "p", "", "Public key file.") - flag.BoolVar(&command.Insecure, "insecure", false, "Set insecure key mode.") - flag.StringVar(&command.Command, "c", "", "Set remote command to run.") - - if len(os.Args) < 2 { - flag.Usage() - os.Exit(0) - } - - flag.Parse() - parseHosts() -} - -func parseHosts() { - CheckStdin() - reader := bufio.NewReader(os.Stdin) - - for { - input, err := reader.ReadString('\n') - if err != nil && err == io.EOF { - if len(command.Hosts) == 0 { - fmt.Println("no hosts provided") - os.Exit(1) - } else { - return - } - } - command.Hosts = append(command.Hosts, strings.Trim(input, "\n")) - } -} - -// CheckStdin ensures os.Stdin is available, and checks the pipe for potential errors. -func CheckStdin() { - stdin, err := os.Stdin.Stat() - if err != nil { - fmt.Printf("Could not read stdin: %s", err) - os.Exit(1) - } - - // Make sure pipe is good - if stdin.Mode()&os.ModeCharDevice != 0 { - fmt.Println("bad pipe or no hosts provided:", stdin.Size()) - os.Exit(1) - } -} diff --git a/_examples/example_bulk_return/main.go b/_examples/example_bulk_return/main.go deleted file mode 100644 index 918b7c6..0000000 --- a/_examples/example_bulk_return/main.go +++ /dev/null @@ -1,58 +0,0 @@ -package main - -import ( - "fmt" - "github.com/discoriver/massh" - "golang.org/x/crypto/ssh" - "os" - "time" -) - -/* -right now everything here is designed as a proof of concept. Things in main need to be worked out, -but for now simply proving that the massh package is behaving as expected is enough. -*/ -func main() { - parseCommands() - - //mConfig := masshConfigBuilder() - mConfig := massh.Config{} - - if err := mConfig.CheckSanity(); err != nil { - fmt.Printf("%s\n", err) - os.Exit(0) - } - fmt.Print(mConfig.Run()) -} - -func masshConfigBuilder() *massh.Config { - config := &massh.Config{ - SSHConfig: &ssh.ClientConfig{ - User: command.User, - Auth: []ssh.AuthMethod{}, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Timeout: time.Duration(command.Timeout) * time.Second, - }, - Job: &massh.Job{}, - WorkerPool: command.WorkerPool, - } - config.SetHosts(command.Hosts) - - if command.PrivateKey != "" { - if err := config.SetPrivateKeyAuth(command.PrivateKey, ""); err != nil { - fmt.Println(err) - os.Exit(1) - } - } - - if command.Script != "" { - err := config.Job.SetLocalScript(command.Script, command.ScriptArgs) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - } else { - config.Job.SetCommand(command.Command) - } - return config -} diff --git a/_examples/example_jobstack_streaming/main.go b/_examples/jobstack_streaming/main.go similarity index 87% rename from _examples/example_jobstack_streaming/main.go rename to _examples/jobstack_streaming/main.go index 8c10294..f40fc05 100644 --- a/_examples/example_jobstack_streaming/main.go +++ b/_examples/jobstack_streaming/main.go @@ -29,13 +29,12 @@ func main() { Timeout: time.Duration(2) * time.Second, } - cfg := &massh.Config{ - // In this example I was testing with two working hosts, and two non-existent IPs. - SSHConfig: sshc, - JobStack: &[]massh.Job{j, j2, j3}, - WorkerPool: 10, - } - cfg.SetHosts([]string{"192.168.1.119", "192.168.1.120", "192.168.1.129", "192.168.1.212"}) + cfg := massh.NewConfig() + cfg.SSHConfig = sshc + cfg.JobStack = &[]massh.Job{j, j2, j3} + cfg.WorkerPool = 10 + cfg.SetHosts([]string{"192.168.1.118"}) + resChan := make(chan massh.Result) diff --git a/_examples/python_script/main.go b/_examples/python_script/main.go new file mode 100644 index 0000000..9fb762f --- /dev/null +++ b/_examples/python_script/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "github.com/discoriver/massh" + "golang.org/x/crypto/ssh" + "time" +) + +func main() { + j := &massh.Job{ + Command: "echo \"Hello, World\"", + } + + sshc := &ssh.ClientConfig{ + // Fake credentials + User: "u01", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: time.Duration(2) * time.Second, + } + + cfg := massh.NewConfig() + cfg.SSHConfig = sshc + cfg.Job = j + cfg.WorkerPool = 10 + cfg.SetHosts([]string{"192.168.1.118", "192.168.1.123"}) + + err := cfg.Job.SetScript("script.py", "") + if err != nil { + panic(err) + } + + res, err := cfg.Run() + if err != nil { + panic(err) + } + + for i := range res { + if res[i].Error != nil { + fmt.Printf("%s: %s\n", res[i].Host, res[i].Error) + } else { + fmt.Printf("%s: %s", res[i].Host, res[i].Output) + } + } +} diff --git a/_examples/python_script/script.py b/_examples/python_script/script.py new file mode 100644 index 0000000..61e0341 --- /dev/null +++ b/_examples/python_script/script.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +import sys +sys.stdout.write("Hello world, from python script." + '\n') \ No newline at end of file diff --git a/_examples/shell_script/main.go b/_examples/shell_script/main.go new file mode 100644 index 0000000..096b986 --- /dev/null +++ b/_examples/shell_script/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "github.com/discoriver/massh" + "golang.org/x/crypto/ssh" + "time" +) + +func main() { + j := &massh.Job{ + Command: "echo \"Hello, World\"", + } + + sshc := &ssh.ClientConfig{ + // Fake credentials + User: "u01", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: time.Duration(2) * time.Second, + } + + cfg := massh.NewConfig() + cfg.SSHConfig = sshc + cfg.Job = j + cfg.WorkerPool = 10 + cfg.SetHosts([]string{"192.168.1.118", "192.168.1.123"}) + + err := cfg.Job.SetScript("script.sh", "") + if err != nil { + panic(err) + } + + res, err := cfg.Run() + if err != nil { + panic(err) + } + + for i := range res { + if res[i].Error != nil { + fmt.Printf("%s: %s\n", res[i].Host, res[i].Error) + } else { + fmt.Printf("%s: %s", res[i].Host, res[i].Output) + } + } +} diff --git a/_examples/shell_script/script.sh b/_examples/shell_script/script.sh new file mode 100644 index 0000000..d0de4fa --- /dev/null +++ b/_examples/shell_script/script.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Hello World, from shell script" \ No newline at end of file diff --git a/_examples/example_streaming/main.go b/_examples/streaming/main.go similarity index 91% rename from _examples/example_streaming/main.go rename to _examples/streaming/main.go index afade3c..7107afe 100644 --- a/_examples/example_streaming/main.go +++ b/_examples/streaming/main.go @@ -21,13 +21,11 @@ func main() { Timeout: time.Duration(2) * time.Second, } - cfg := &massh.Config{ - // In this example I was testing with two working hosts, and two non-existent IPs. - SSHConfig: sshc, - Job: j, - WorkerPool: 10, - } - cfg.SetHosts([]string{"192.168.1.119", "192.168.1.120", "192.168.1.129", "192.168.1.212"}) + cfg := massh.NewConfig() + cfg.SSHConfig = sshc + cfg.Job = j + cfg.WorkerPool = 10 + cfg.SetHosts([]string{"192.168.1.118", "192.168.1.119", "192.168.1.120", "192.168.1.129", "192.168.1.212"}) resChan := make(chan massh.Result) diff --git a/_examples/streaming_script/main.go b/_examples/streaming_script/main.go new file mode 100644 index 0000000..3a74282 --- /dev/null +++ b/_examples/streaming_script/main.go @@ -0,0 +1,83 @@ +package main + +import ( + "fmt" + "github.com/discoriver/massh" + "golang.org/x/crypto/ssh" + "sync" + "time" +) + +func main() { + sshc := &ssh.ClientConfig{ + // Fake credentials + User: "u01", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: time.Duration(2) * time.Second, + } + + cfg := massh.NewConfig() + cfg.SSHConfig = sshc + cfg.WorkerPool = 10 + cfg.Job = &massh.Job{} + cfg.SetHosts([]string{"192.168.1.118"}) + cfg.Job.SetScript("script.sh", "") + + resChan := make(chan massh.Result) + + // This should be the last responsibility from the massh package. Handling the Result channel is up to the user. + err := cfg.Stream(resChan) + if err != nil { + panic(err) + } + + var wg sync.WaitGroup + // This can probably be cleaner. We're hindered somewhat, I think, by reading a channel from a channel. + for { + select { + case result := <-resChan: + wg.Add(1) + go func() { + // Need to handle any errors as the existence of this value indicates that the ssh task wasn't started + // due to some functional error. + // + // The reason for this design is that it was important to me not to have the cfg.Stream function return + // anything, and having it as part of the Result means we can more easily associate the error with a + // host. + if result.Error != nil { + fmt.Printf("%s: %s\n", result.Host, result.Error) + wg.Done() + } else { + readStream(result, &wg) + } + }() + default: + if massh.NumberOfStreamingHostsCompleted == len(cfg.Hosts) { + // We want to wait for all goroutines to complete before we declare that the work is finished, as + // it's possible for us to execute this code before the gofunc above has completed if left unchecked. + wg.Wait() + + // This should always be the last thing written. Waiting above ensures this. + fmt.Println("Everything returned.") + return + } + } + } +} + +// Read Stdout stream +func readStream(res massh.Result, wg *sync.WaitGroup) error { + for { + select { + case d := <-res.StdOutStream: + fmt.Printf("STDOUT %s: %s", res.Host, d) + case e := <-res.StdErrStream: + fmt.Printf("STDERR %s: %s", res.Host, e) + case <-res.DoneChannel: + // Confirm that the host has exited. + fmt.Printf("%s: Finished\n", res.Host) + wg.Done() + } + } +} diff --git a/_examples/streaming_script/script.sh b/_examples/streaming_script/script.sh new file mode 100644 index 0000000..d0de4fa --- /dev/null +++ b/_examples/streaming_script/script.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Hello World, from shell script" \ No newline at end of file diff --git a/_examples/streaming_script_jobstack/main.go b/_examples/streaming_script_jobstack/main.go new file mode 100644 index 0000000..c1b9a4b --- /dev/null +++ b/_examples/streaming_script_jobstack/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "fmt" + "github.com/discoriver/massh" + "golang.org/x/crypto/ssh" + "sync" + "time" +) + +func main() { + j1 := massh.Job{} + j2 := massh.Job{} + j3 := massh.Job{} + + j1.SetScript("script1.sh", "") + j2.SetScript("script2.sh", "") + j3.SetScript("script3.sh", "") + + sshc := &ssh.ClientConfig{ + // Fake credentials + User: "u01", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: time.Duration(2) * time.Second, + } + + cfg := massh.NewConfig() + cfg.SSHConfig = sshc + cfg.WorkerPool = 10 + cfg.JobStack = &[]massh.Job{j1, j2, j3} + cfg.SetHosts([]string{"192.168.1.118"}) + + resChan := make(chan massh.Result) + + // This should be the last responsibility from the massh package. Handling the Result channel is up to the user. + err := cfg.Stream(resChan) + if err != nil { + panic(err) + } + + var wg sync.WaitGroup + numberOfExpectedCompletions := len(cfg.Hosts) * len(*cfg.JobStack) + // This can probably be cleaner. We're hindered somewhat, I think, by reading a channel from a channel. + for { + select { + case result := <-resChan: + wg.Add(1) + go func() { + // Need to handle any errors as the existence of this value indicates that the ssh task wasn't started + // due to some functional error. + // + // The reason for this design is that it was important to me not to have the cfg.Stream function return + // anything, and having it as part of the Result means we can more easily associate the error with a + // host. + if result.Error != nil { + fmt.Printf("%s: %s\n", result.Host, result.Error) + wg.Done() + } else { + readStream(result, &wg) + } + }() + default: + if massh.NumberOfStreamingHostsCompleted == numberOfExpectedCompletions { + // We want to wait for all goroutines to complete before we declare that the work is finished, as + // it's possible for us to execute this code before the gofunc above has completed if left unchecked. + wg.Wait() + + // This should always be the last thing written. Waiting above ensures this. + fmt.Println("Everything returned.") + return + } + } + } +} + +// Read Stdout stream +func readStream(res massh.Result, wg *sync.WaitGroup) error { + for { + select { + case d := <-res.StdOutStream: + fmt.Printf("STDOUT %s: %s", res.Host, d) + case e := <-res.StdErrStream: + fmt.Printf("STDERR %s: %s", res.Host, e) + case <-res.DoneChannel: + // Confirm that the host has exited. + fmt.Printf("%s: Finished\n", res.Host) + wg.Done() + } + } +} diff --git a/_examples/streaming_script_jobstack/script1.sh b/_examples/streaming_script_jobstack/script1.sh new file mode 100644 index 0000000..7631bea --- /dev/null +++ b/_examples/streaming_script_jobstack/script1.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Hello World, from the first shell script" \ No newline at end of file diff --git a/_examples/streaming_script_jobstack/script2.sh b/_examples/streaming_script_jobstack/script2.sh new file mode 100644 index 0000000..bae1633 --- /dev/null +++ b/_examples/streaming_script_jobstack/script2.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Hello World, from the second shell script" \ No newline at end of file diff --git a/_examples/streaming_script_jobstack/script3.sh b/_examples/streaming_script_jobstack/script3.sh new file mode 100644 index 0000000..e97a8bf --- /dev/null +++ b/_examples/streaming_script_jobstack/script3.sh @@ -0,0 +1,2 @@ +#!/bin/bash +echo "Hello World, from the third shell script" \ No newline at end of file From 712533871faedea3c764be6501eadb7939ad97bd Mon Sep 17 00:00:00 2001 From: Disco Date: Sun, 12 Dec 2021 21:56:25 +0000 Subject: [PATCH 3/3] Updated script logic with interface. Fixed bug where Jobstack was only running the last entry in the stack, repeated for it's length, due to incorrectly copied config struct. --- job.go | 21 ++++------ script.go | 110 +++++++++++++++++++++++++++++++++++++++++++++++++++++ session.go | 45 +++++++++++++++------- 3 files changed, 148 insertions(+), 28 deletions(-) create mode 100644 script.go diff --git a/job.go b/job.go index ae00ff3..b6142dc 100644 --- a/job.go +++ b/job.go @@ -1,15 +1,9 @@ package massh -import ( - "fmt" - "io/ioutil" -) - // Job is a single remote task config. For script files, use Job.SetLocalScript(). type Job struct { - Command string - Script []byte - ScriptArgs string + Command string + Script script } // SetCommand sets the Command value in Job. This is the Command executed over SSH to all hosts. @@ -17,14 +11,13 @@ func (j *Job) SetCommand(command string) { j.Command = command } -// SetLocalScript reads a script file contents into the Job config. -func (j *Job) SetLocalScript(filename string, args string) error { - var err error - j.Script, err = ioutil.ReadFile(filename) +func (j *Job) SetScript(filePath string, args ...string) error { + s, err := newScript(filePath, args...) if err != nil { - return fmt.Errorf("failed to read script file") + return err } - j.ScriptArgs = args + + j.Script = s return nil } diff --git a/script.go b/script.go new file mode 100644 index 0000000..d30a73a --- /dev/null +++ b/script.go @@ -0,0 +1,110 @@ +package massh + +import ( + "bytes" + "fmt" + "golang.org/x/crypto/ssh" + "io/ioutil" + "path/filepath" + "strings" +) + +type script interface { + prepare(*ssh.Session) + getPreparedCommandString() string + getBytes() []byte + getArgs() string +} + +type shell struct { + bytes []byte + args string + + commandString string + + prepared bool +} + +type python struct { + bytes []byte + args string + + commandString string + + prepared bool +} + +// NewScript creates a new script type based on the file extension. Shebang line in supported scripts must be present. +// +// Each element in args should ideally contain an argument's key/value, for example "--some-arg value", or "--some-arg=value". +func newScript(scriptFile string, args ...string) (script, error) { + scriptBytes, err := ioutil.ReadFile(scriptFile) + if err != nil { + return nil, fmt.Errorf("failed to read script file: %s", err) + } + + // Check shebang is present + if scriptBytes[0] != '#' { + return nil, fmt.Errorf("shebang line not present in file %s", filepath.Base(scriptFile)) + } + + if strings.HasSuffix(scriptFile, ".sh") { + shellScript := &shell{ + bytes: scriptBytes, + args: strings.Join(args, " "), + } + return shellScript, nil + } + + if strings.HasSuffix(scriptFile, ".py") { + pythonScript := &python{ + bytes: scriptBytes, + args: strings.Join(args, " "), + } + return pythonScript, nil + } + + return nil, fmt.Errorf("script file %s not supported", filepath.Base(scriptFile)) +} + +// Prepare populated the SSH sessions's stdin with the script data, and returns a command string to run the script from a temporary file. +func (s *shell) prepare(session *ssh.Session) { + // Set up remote script + session.Stdin = bytes.NewReader(s.bytes) + + s.commandString = fmt.Sprintf("cat > massh-script-tmp.sh && chmod +x ./massh-script-tmp.sh && ./massh-script-tmp.sh %s && rm ./massh-script-tmp.sh", s.args) + s.prepared = true +} + +func (s *shell) getPreparedCommandString() string { + return s.commandString +} + +func (s *shell) getBytes() []byte { + return s.bytes +} + +func (s *shell) getArgs() string { + return s.args +} + +// Prepare populated the SSH sessions's stdin with the script data, and returns a command string to run the script from a temporary file. +func (s *python) prepare(session *ssh.Session) { + // Set up remote script + session.Stdin = bytes.NewReader(s.bytes) + + s.commandString = fmt.Sprintf("cat > massh-script-tmp.py && chmod +x ./massh-script-tmp.py && ./massh-script-tmp.py %s && rm ./massh-script-tmp.py", s.args) + s.prepared = true +} + +func (s *python) getPreparedCommandString() string { + return s.commandString +} + +func (s *python) getBytes() []byte { + return s.bytes +} + +func (s *python) getArgs() string { + return s.args +} diff --git a/session.go b/session.go index 8f3c650..20bdfae 100644 --- a/session.go +++ b/session.go @@ -34,12 +34,12 @@ type Result struct { DoneChannel chan struct{} } -// getJob determines the type of job and returns the command string +// getJob determines the type of job and returns the command string. If type is a local script, then stdin will be populated with the script data and sent/executed on the remote machine. func getJob(s *ssh.Session, j *Job) string { // Set up remote script if j.Script != nil { - s.Stdin = bytes.NewReader(j.Script) - return fmt.Sprintf("cat > outfile.sh && chmod +x ./outfile.sh && ./outfile.sh %s && rm ./outfile.sh", j.ScriptArgs) + j.Script.prepare(s) + return j.Script.getPreparedCommandString() } return j.Command @@ -166,7 +166,7 @@ func sshCommandStream(host string, config *Config, resultChannel chan Result) { // readToBytesChannel reads from io.Reader and directs the data to a byte slice channel for streaming. func readToBytesChannel(reader io.Reader, stream chan []byte, r Result, wg *sync.WaitGroup) { - defer func(){ wg.Done() }() + defer func() { wg.Done() }() rdr := bufio.NewReader(reader) @@ -192,12 +192,15 @@ func worker(hosts <-chan string, results chan<- Result, config *Config, resChan // TODO: Make the handling of a JobStack more elegant. if resChan == nil { for host := range hosts { - cfg := *config - if cfg.JobStack != nil { - for i := range *cfg.JobStack { - j := (*cfg.JobStack)[i] + if config.JobStack != nil { + for i := range *config.JobStack { + // Cfg is a copy of config, without job pointers. This is needed to separate the jobstack. + cfg := copyConfigNoJobs(config) + + j := (*config.JobStack)[i] cfg.Job = &j - results <- sshCommand(host, &cfg) + + results <- sshCommand(host, cfg) } } else { results <- sshCommand(host, config) @@ -205,12 +208,15 @@ func worker(hosts <-chan string, results chan<- Result, config *Config, resChan } } else { for host := range hosts { - cfg := *config - if cfg.JobStack != nil { - for i := range *cfg.JobStack { - j := (*cfg.JobStack)[i] + if config.JobStack != nil { + for i := range *config.JobStack { + // Cfg is a copy of config, without job pointers. This is needed to separate the jobstack. + cfg := copyConfigNoJobs(config) + + j := (*config.JobStack)[i] cfg.Job = &j - go sshCommandStream(host, &cfg, resChan) + + go sshCommandStream(host, cfg, resChan) } } else { go sshCommandStream(host, config, resChan) @@ -219,6 +225,17 @@ func worker(hosts <-chan string, results chan<- Result, config *Config, resChan } } +func copyConfigNoJobs(config *Config) *Config { + cfg := NewConfig() + cfg.Hosts = config.Hosts + cfg.SSHConfig = config.SSHConfig + cfg.BastionHost = config.BastionHost + cfg.BastionHostSSHConfig = config.BastionHostSSHConfig + cfg.WorkerPool = config.WorkerPool + + return cfg +} + // runStream is mostly the same as run, except it directs the results to a channel so they can be processed // before the command has completed executing (i.e streaming the stdout and stderr as it runs). func runStream(c *Config, rs chan Result) {