Skip to content

Commit

Permalink
Add a feature for set environment variables
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnavi committed Dec 14, 2024
1 parent 19dee2a commit e5cc864
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 36 deletions.
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Komandan has `komando` function that takes two arguments:
- `ignore_exit_code`: a boolean that indicates whether to ignore the exit code of the task. If `true`, the script will continue even if the task returns a non-zero exit code. (default is `false`)
- `elevate`: a boolean that indicates whether to run the task as root. (default is `false`)
- `as_user`: a string that specifies the user to run the task as. Requires `elevate` to be `true`. (optional)
- `env`: a table that contains environment variables to be set before executing the task. (optional)

This function will execute the module on the target server and return the results:
- `stdout`: a string that contains the standard output of the module.
Expand All @@ -83,30 +84,68 @@ This function will execute the module on the target server and return the result
## Modules

Komandan has several built-in modules that can be used to perform various tasks on the target server. These modules are located in the `komandan.modules` table.

### `cmd` module

The `cmd` module allows you to execute a shell command on the target server. It takes the following arguments:
- `cmd`: a string that contains the shell command to be executed.

Usage example:
```lua
local task = komandan.modules.cmd({
cmd = "mkdir /tmp/newdir"
})
```

### `script` module

The `script` module allows you to execute a script on the target server. It takes the following arguments:
- `script`: a string that contains the script to be executed.
- `from_file`: a string that contains the local path to the script file to be executed on the target server. (`script` and `from_file` parameters are mutually exclusive)
- `interpreter`: a string that specifies the interpreter to use for the script. If not specified, the script will be executed using the default shell.

Usage example:

```lua
local task = komandan.modules.script({
script = "print('Hello from Komandan!')"
-- or
from_file = "/local_path/to/script.py"

interpreter = "python3"
})
```

### `upload` module

The `upload` module allows you to upload a file to the target server. It takes the following arguments:
- `src`: a string that contains the path to the file to be uploaded.
- `dst`: a string that contains the path to the destination file on the target server.

Usage example:

```lua
local task = komandan.modules.upload({
src = "/local_path/to/file.txt",
dst = "/remote_path/to/file.txt"
})
```

### `download` module

The `download` module allows you to download a file from the target server. It takes the following arguments:
- `src`: a string that contains the path to the file to be downloaded.
- `dst`: a string that contains the path to the destination file on the local machine.

Usage example:

```lua
local task = komandan.modules.download({
src = "/remote_path/to/file.txt",
dst = "/local_path/to/file.txt"
})
```

### `apt` module

The `apt` module allows you to install packages on the target server. It takes the following arguments:
Expand All @@ -115,6 +154,16 @@ The `apt` module allows you to install packages on the target server. It takes t
- `update_cache`: a boolean that indicates whether to update the package cache before installing the package. (default is `false`)
- `install_recommends`: a boolean that indicates whether to install recommended packages. (default is `true`)

Usage example:

```lua
local task = komandan.modules.apt({
package = "nginx",
update_cache = true,
elevate = true
})
```

## Built-in functions

Komandan provides several built-in functions that can be used to help write scripts.
Expand Down
4 changes: 4 additions & 0 deletions src/defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,9 @@ pub fn defaults(lua: &Lua) -> mlua::Result<Table> {
defaults.set("elevate", false)?;
defaults.set("elevation_method", "sudo")?;

let env = lua.create_table()?;
env.set("DEBIAN_FRONTEND", "noninteractive")?;
defaults.set("env", env)?;

Ok(defaults)
}
29 changes: 23 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,35 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result<Table>
as_user,
};

let ssh = SSHSession::connect(
let mut ssh = SSHSession::connect(
(host.get::<String>("address")?.as_str(), port),
&user,
ssh_auth_method,
elevation,
)?;

let env_defaults = defaults.get::<Table>("env").unwrap_or(lua.create_table()?);
let env_host = host.get::<Table>("env").unwrap_or(lua.create_table()?);
let env_task = task.get::<Table>("env").unwrap_or(lua.create_table()?);

for pair in env_defaults.pairs() {
let (key, value): (String, String) = pair?;
println!("{}={}", key, value);
ssh.set_env(&key, &value);
}

for pair in env_host.pairs() {
let (key, value): (String, String) = pair?;
println!("{}={}", key, value);
ssh.set_env(&key, &value);
}

for pair in env_task.pairs() {
let (key, value): (String, String) = pair?;
println!("{}={}", key, value);
ssh.set_env(&key, &value);
}

let module_clone = module.clone();
let results = lua
.load(chunk! {
Expand All @@ -285,11 +307,6 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result<Table>
})
.eval::<()>()?;

let defaults = lua
.globals()
.get::<Table>("komandan")?
.get::<Table>("defaults")?;

let ignore_exit_code = task
.get::<bool>("ignore_exit_code")
.unwrap_or_else(|_| defaults.get::<bool>("ignore_exit_code").unwrap());
Expand Down
70 changes: 40 additions & 30 deletions src/ssh.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
collections::HashMap,
fs,
io::{self, Read, Write},
net::{TcpStream, ToSocketAddrs},
Expand Down Expand Up @@ -30,6 +31,7 @@ pub enum ElevateMethod {

pub struct SSHSession {
session: Session,
env: HashMap<String, String>,
elevation: Elevation,
stdout: Option<String>,
stderr: Option<String>,
Expand Down Expand Up @@ -72,24 +74,38 @@ impl SSHSession {

Ok(Self {
session,
env: HashMap::new(),
elevation,
stdout: Some(String::new()),
stderr: Some(String::new()),
exit_code: Some(0),
})
}

pub fn set_env(&mut self, key: &str, value: &str) {
*self.env.entry(key.to_string()).or_insert(value.to_string()) = value.to_string();
}

fn execute_command(&mut self, command: &str) -> Result<ssh2::Channel> {
let mut channel = self.session.channel_session()?;
let mut command = command.to_string();
for (key, value) in &self.env {
command = format!("export {}={}\n", key, value) + &command;
}
channel.exec(command.as_str())?;
Ok(channel)
}

pub fn cmd(&mut self, command: &str) -> Result<(String, String, i32)> {
let mut channel = self.session.channel_session().unwrap();
channel.exec(command).unwrap();
let mut channel = self.execute_command(command)?;
let mut stdout = String::new();
let mut stderr = String::new();

channel.read_to_string(&mut stdout).unwrap();
channel.stderr().read_to_string(&mut stderr).unwrap();
channel.read_to_string(&mut stdout)?;
channel.stderr().read_to_string(&mut stderr)?;
stdout = stdout.trim_end_matches('\n').to_string();
channel.wait_close().unwrap();
let exit_code = channel.exit_status().unwrap();
channel.wait_close()?;
let exit_code = channel.exit_status()?;

self.stdout.as_mut().unwrap().push_str(&stdout);
self.stderr.as_mut().unwrap().push_str(&stderr);
Expand All @@ -115,32 +131,27 @@ impl SSHSession {
}

pub fn get_remote_env(&mut self, var: &str) -> Result<String> {
let mut channel = self.session.channel_session().unwrap();
channel.exec(format!("echo ${}", var).as_str()).unwrap();
let mut channel = self.execute_command(format!("echo ${}", var).as_str())?;
let mut stdout = String::new();
channel.read_to_string(&mut stdout).unwrap();
channel.read_to_string(&mut stdout)?;
stdout = stdout.trim_end_matches('\n').to_string();
channel.wait_close().unwrap();
channel.wait_close()?;

Ok(stdout)
}

pub fn get_tmpdir(&mut self) -> Result<String> {
let mut channel = self.session.channel_session().unwrap();
channel.exec("tmpdir=`for dir in \"$HOME/.komandan/tmp\" \"/tmp/komandan\"; do if [ -d \"$dir\" ] || mkdir -p \"$dir\" 2>/dev/null; then echo \"$dir\"; break; fi; done`; [ -z \"$tmpdir\" ] && { exit 1; } || echo \"$tmpdir\"").unwrap();
let mut channel = self.execute_command("tmpdir=`for dir in \"$HOME/.komandan/tmp\" \"/tmp/komandan\"; do if [ -d \"$dir\" ] || mkdir -p \"$dir\" 2>/dev/null; then echo \"$dir\"; break; fi; done`; [ -z \"$tmpdir\" ] && { exit 1; } || echo \"$tmpdir\"")?;
let mut stdout = String::new();
channel.read_to_string(&mut stdout).unwrap();
channel.read_to_string(&mut stdout)?;
stdout = stdout.trim_end_matches('\n').to_string();
channel.wait_close().unwrap();
channel.wait_close()?;

Ok(stdout)
}

pub fn chmod(&mut self, remote_path: &Path, mode: &String) -> Result<()> {
let mut channel = self.session.channel_session().unwrap();
channel
.exec(format!("chmod {} {}", mode, remote_path.to_string_lossy()).as_str())
.unwrap();
self.execute_command(format!("chmod {} {}", mode, remote_path.to_string_lossy()).as_str())?;

Ok(())
}
Expand Down Expand Up @@ -171,15 +182,14 @@ impl SSHSession {

pub fn write_remote_file(&mut self, remote_path: &str, content: &[u8]) -> Result<()> {
let content_length = content.len() as u64;
let mut remote_file = self
.session
.scp_send(Path::new(remote_path), 0o644, content_length, None)
.unwrap();
remote_file.write(content).unwrap();
remote_file.send_eof().unwrap();
remote_file.wait_eof().unwrap();
remote_file.close().unwrap();
remote_file.wait_close().unwrap();
let mut remote_file =
self.session
.scp_send(Path::new(remote_path), 0o644, content_length, None)?;
remote_file.write(content)?;
remote_file.send_eof()?;
remote_file.wait_eof()?;
remote_file.close()?;
remote_file.wait_close()?;

Ok(())
}
Expand Down Expand Up @@ -251,9 +261,9 @@ fn download_directory(sftp: &mut Sftp, remote_path: &Path, local_path: &Path) ->
impl UserData for SSHSession {
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_method_mut("cmd", |lua, this, command: String| {
let command = this.prepare_command(command.as_str()).unwrap();
let command = this.prepare_command(command.as_str())?;
let cmd_result = this.cmd(&command);
let (stdout, stderr, exit_code) = cmd_result.unwrap();
let (stdout, stderr, exit_code) = cmd_result?;

let table = lua.create_table()?;
table.set("stdout", stdout)?;
Expand Down Expand Up @@ -299,7 +309,7 @@ impl UserData for SSHSession {
});

methods.add_method_mut("get_tmpdir", |_, this, ()| {
let tmpdir = this.get_tmpdir().unwrap();
let tmpdir = this.get_tmpdir()?;
Ok(tmpdir)
});

Expand Down

0 comments on commit e5cc864

Please sign in to comment.