From 712533871faedea3c764be6501eadb7939ad97bd Mon Sep 17 00:00:00 2001 From: Disco Date: Sun, 12 Dec 2021 21:56:25 +0000 Subject: [PATCH] Updated script logic with interface. Fixed bug where Jobstack was only running the last entry in the stack, repeated for it's length, due to incorrectly copied config struct. --- job.go | 21 ++++------ script.go | 110 +++++++++++++++++++++++++++++++++++++++++++++++++++++ session.go | 45 +++++++++++++++------- 3 files changed, 148 insertions(+), 28 deletions(-) create mode 100644 script.go diff --git a/job.go b/job.go index ae00ff3..b6142dc 100644 --- a/job.go +++ b/job.go @@ -1,15 +1,9 @@ package massh -import ( - "fmt" - "io/ioutil" -) - // Job is a single remote task config. For script files, use Job.SetLocalScript(). type Job struct { - Command string - Script []byte - ScriptArgs string + Command string + Script script } // SetCommand sets the Command value in Job. This is the Command executed over SSH to all hosts. @@ -17,14 +11,13 @@ func (j *Job) SetCommand(command string) { j.Command = command } -// SetLocalScript reads a script file contents into the Job config. -func (j *Job) SetLocalScript(filename string, args string) error { - var err error - j.Script, err = ioutil.ReadFile(filename) +func (j *Job) SetScript(filePath string, args ...string) error { + s, err := newScript(filePath, args...) if err != nil { - return fmt.Errorf("failed to read script file") + return err } - j.ScriptArgs = args + + j.Script = s return nil } diff --git a/script.go b/script.go new file mode 100644 index 0000000..d30a73a --- /dev/null +++ b/script.go @@ -0,0 +1,110 @@ +package massh + +import ( + "bytes" + "fmt" + "golang.org/x/crypto/ssh" + "io/ioutil" + "path/filepath" + "strings" +) + +type script interface { + prepare(*ssh.Session) + getPreparedCommandString() string + getBytes() []byte + getArgs() string +} + +type shell struct { + bytes []byte + args string + + commandString string + + prepared bool +} + +type python struct { + bytes []byte + args string + + commandString string + + prepared bool +} + +// NewScript creates a new script type based on the file extension. Shebang line in supported scripts must be present. +// +// Each element in args should ideally contain an argument's key/value, for example "--some-arg value", or "--some-arg=value". +func newScript(scriptFile string, args ...string) (script, error) { + scriptBytes, err := ioutil.ReadFile(scriptFile) + if err != nil { + return nil, fmt.Errorf("failed to read script file: %s", err) + } + + // Check shebang is present + if scriptBytes[0] != '#' { + return nil, fmt.Errorf("shebang line not present in file %s", filepath.Base(scriptFile)) + } + + if strings.HasSuffix(scriptFile, ".sh") { + shellScript := &shell{ + bytes: scriptBytes, + args: strings.Join(args, " "), + } + return shellScript, nil + } + + if strings.HasSuffix(scriptFile, ".py") { + pythonScript := &python{ + bytes: scriptBytes, + args: strings.Join(args, " "), + } + return pythonScript, nil + } + + return nil, fmt.Errorf("script file %s not supported", filepath.Base(scriptFile)) +} + +// Prepare populated the SSH sessions's stdin with the script data, and returns a command string to run the script from a temporary file. +func (s *shell) prepare(session *ssh.Session) { + // Set up remote script + session.Stdin = bytes.NewReader(s.bytes) + + s.commandString = fmt.Sprintf("cat > massh-script-tmp.sh && chmod +x ./massh-script-tmp.sh && ./massh-script-tmp.sh %s && rm ./massh-script-tmp.sh", s.args) + s.prepared = true +} + +func (s *shell) getPreparedCommandString() string { + return s.commandString +} + +func (s *shell) getBytes() []byte { + return s.bytes +} + +func (s *shell) getArgs() string { + return s.args +} + +// Prepare populated the SSH sessions's stdin with the script data, and returns a command string to run the script from a temporary file. +func (s *python) prepare(session *ssh.Session) { + // Set up remote script + session.Stdin = bytes.NewReader(s.bytes) + + s.commandString = fmt.Sprintf("cat > massh-script-tmp.py && chmod +x ./massh-script-tmp.py && ./massh-script-tmp.py %s && rm ./massh-script-tmp.py", s.args) + s.prepared = true +} + +func (s *python) getPreparedCommandString() string { + return s.commandString +} + +func (s *python) getBytes() []byte { + return s.bytes +} + +func (s *python) getArgs() string { + return s.args +} diff --git a/session.go b/session.go index 8f3c650..20bdfae 100644 --- a/session.go +++ b/session.go @@ -34,12 +34,12 @@ type Result struct { DoneChannel chan struct{} } -// getJob determines the type of job and returns the command string +// getJob determines the type of job and returns the command string. If type is a local script, then stdin will be populated with the script data and sent/executed on the remote machine. func getJob(s *ssh.Session, j *Job) string { // Set up remote script if j.Script != nil { - s.Stdin = bytes.NewReader(j.Script) - return fmt.Sprintf("cat > outfile.sh && chmod +x ./outfile.sh && ./outfile.sh %s && rm ./outfile.sh", j.ScriptArgs) + j.Script.prepare(s) + return j.Script.getPreparedCommandString() } return j.Command @@ -166,7 +166,7 @@ func sshCommandStream(host string, config *Config, resultChannel chan Result) { // readToBytesChannel reads from io.Reader and directs the data to a byte slice channel for streaming. func readToBytesChannel(reader io.Reader, stream chan []byte, r Result, wg *sync.WaitGroup) { - defer func(){ wg.Done() }() + defer func() { wg.Done() }() rdr := bufio.NewReader(reader) @@ -192,12 +192,15 @@ func worker(hosts <-chan string, results chan<- Result, config *Config, resChan // TODO: Make the handling of a JobStack more elegant. if resChan == nil { for host := range hosts { - cfg := *config - if cfg.JobStack != nil { - for i := range *cfg.JobStack { - j := (*cfg.JobStack)[i] + if config.JobStack != nil { + for i := range *config.JobStack { + // Cfg is a copy of config, without job pointers. This is needed to separate the jobstack. + cfg := copyConfigNoJobs(config) + + j := (*config.JobStack)[i] cfg.Job = &j - results <- sshCommand(host, &cfg) + + results <- sshCommand(host, cfg) } } else { results <- sshCommand(host, config) @@ -205,12 +208,15 @@ func worker(hosts <-chan string, results chan<- Result, config *Config, resChan } } else { for host := range hosts { - cfg := *config - if cfg.JobStack != nil { - for i := range *cfg.JobStack { - j := (*cfg.JobStack)[i] + if config.JobStack != nil { + for i := range *config.JobStack { + // Cfg is a copy of config, without job pointers. This is needed to separate the jobstack. + cfg := copyConfigNoJobs(config) + + j := (*config.JobStack)[i] cfg.Job = &j - go sshCommandStream(host, &cfg, resChan) + + go sshCommandStream(host, cfg, resChan) } } else { go sshCommandStream(host, config, resChan) @@ -219,6 +225,17 @@ func worker(hosts <-chan string, results chan<- Result, config *Config, resChan } } +func copyConfigNoJobs(config *Config) *Config { + cfg := NewConfig() + cfg.Hosts = config.Hosts + cfg.SSHConfig = config.SSHConfig + cfg.BastionHost = config.BastionHost + cfg.BastionHostSSHConfig = config.BastionHostSSHConfig + cfg.WorkerPool = config.WorkerPool + + return cfg +} + // runStream is mostly the same as run, except it directs the results to a channel so they can be processed // before the command has completed executing (i.e streaming the stdout and stderr as it runs). func runStream(c *Config, rs chan Result) {