Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnavi committed Dec 20, 2024
1 parent 2263032 commit 708b5b0
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 31 deletions.
35 changes: 13 additions & 22 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ use rustyline::DefaultEditor;
use ssh::{ElevateMethod, Elevation, SSHAuthMethod, SSHSession};
use std::{env, path::Path};
use util::{
dprint, filter_hosts, hostname_display, parse_hosts_json, regex_is_match, set_defaults,
dprint, filter_hosts, host_display, parse_hosts_json, regex_is_match, set_defaults,
task_display,
};
use validator::{validate_host, validate_module, validate_task};
use validator::{validate_host, validate_task};

pub fn setup_lua_env(lua: &Lua) -> mlua::Result<()> {
let args = Args::parse();
Expand Down Expand Up @@ -86,16 +87,10 @@ pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> {
async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result<Table> {
let host = lua.create_function(validate_host)?.call::<Table>(&host)?;
let task = lua.create_function(validate_task)?.call::<Table>(&task)?;
let module = lua
.create_function(validate_module)?
.call::<Table>(task.get::<Table>(1).unwrap())?;
let module = task.get::<Table>(1)?;

let host_display = hostname_display(&host);

let task_display = match task.get::<String>("name") {
Ok(name) => name,
Err(_) => module.get::<String>("name")?,
};
let host_display = host_display(&host);
let task_display = task_display(&task);

let defaults = lua
.globals()
Expand Down Expand Up @@ -287,32 +282,28 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result<Table>
ssh.set_env(&key, &value);
}

let module_clone = module.clone();
let results = lua
.load(chunk! {
print("Running task '" .. $task_display .. "' on host '" .. $host_display .."' ...")
$module_clone.ssh = $ssh
$module_clone:run()
$module.ssh = $ssh
$module:run()

local results = $module_clone.ssh:get_session_results()
local results = $module.ssh:get_session_results()
komandan.dprint(results.stdout)
if results.exit_code ~= 0 then
print("Task '" .. $task_display .. "' on host '" .. $host_display .."' failed with exit code " .. results.exit_code .. ": " .. results.stderr)
else
print("Task '" .. $task_display .. "' on host '" .. $host_display .."' succeeded.")
end

if $module.cleanup ~= nil then
$module:cleanup()
end

return results
})
.eval::<Table>()?;

lua.load(chunk! {
if $module.cleanup ~= nil then
$module:cleanup()
end
})
.eval::<()>()?;

let ignore_exit_code = task
.get::<bool>("ignore_exit_code")
.unwrap_or_else(|_| defaults.get::<bool>("ignore_exit_code").unwrap());
Expand Down
14 changes: 11 additions & 3 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ pub fn regex_is_match(
Ok(re.is_match(&text.to_str()?))
}

pub fn hostname_display(host: &Table) -> String {
pub fn host_display(host: &Table) -> String {
let address = host.get::<String>("address").unwrap();

match host.get::<String>("name") {
Expand All @@ -187,6 +187,14 @@ pub fn hostname_display(host: &Table) -> String {
}
}

pub fn task_display(task: &Table) -> String {
let module = task.get::<Table>(1).unwrap();
match task.get::<String>("name") {
Ok(name) => name,
Err(_) => module.get::<String>("name").unwrap(),
}
}

// Tests
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -559,11 +567,11 @@ mod tests {
let host = lua.create_table().unwrap();
host.set("address", "192.168.1.1").unwrap();
host.set("name", "test").unwrap();
assert_eq!(hostname_display(&host), "test (192.168.1.1)");
assert_eq!(host_display(&host), "test (192.168.1.1)");

// Test without name
let host = lua.create_table().unwrap();
host.set("address", "10.0.0.1").unwrap();
assert_eq!(hostname_display(&host), "10.0.0.1");
assert_eq!(host_display(&host), "10.0.0.1");
}
}
15 changes: 10 additions & 5 deletions src/validator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use mlua::{chunk, Error::RuntimeError, Integer, Lua, Table, Value};
use mlua::{chunk, Error::RuntimeError, ExternalResult, Integer, Lua, Table, Value};

pub fn validate_host(lua: &Lua, host: Value) -> mlua::Result<Table> {
if !host.is_table() {
Expand Down Expand Up @@ -33,16 +33,19 @@ fn validate_port(_: &Lua, port: Value) -> mlua::Result<Integer> {
Ok(port.as_integer().unwrap())
}

pub fn validate_task(_: &Lua, task: Value) -> mlua::Result<Table> {
pub fn validate_task(lua: &Lua, task: Value) -> mlua::Result<Table> {
if !task.is_table() {
return Err(RuntimeError("Task is not a table.".to_string()));
}

if task.as_table().unwrap().get::<Value>(1)?.is_nil() {
let task = task.as_table().unwrap();
if task.get::<Value>(1)?.is_nil() {
return Err(RuntimeError("Task is invalid.".to_string()));
}

Ok(task.as_table().unwrap().to_owned())
validate_module(lua, task.get::<Value>(1)?).into_lua_err()?;

Ok(task.to_owned())
}

pub fn validate_module(lua: &Lua, module: Value) -> mlua::Result<Table> {
Expand Down Expand Up @@ -206,7 +209,9 @@ mod tests {
fn test_validate_task_valid() {
let lua = Lua::new();
let task = lua.create_table().unwrap();
task.set(1, "cmd").unwrap();
let module = lua.create_table().unwrap();
module.set("name", "cmd").unwrap();
task.set(1, module).unwrap();

let result = super::validate_task(&lua, mlua::Value::Table(task));
assert!(result.is_ok());
Expand Down
87 changes: 86 additions & 1 deletion tests/komando.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use komandan::setup_lua_env;
use mlua::{chunk, Integer, Lua, Table};
use std::io::Write;
use std::{env, io::Write};
use tempfile::NamedTempFile;

#[test]
Expand Down Expand Up @@ -82,6 +82,91 @@ fn test_komando_userauth_invalid_password() {
assert!(result.is_err());
}

#[test]
fn test_komando_use_default_user() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();

let result = lua
.load(chunk! {
komandan.set_defaults({
user = "usertest",
})

local hosts = {
address = "localhost",
private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519"
}

local task = {
komandan.modules.cmd({
cmd = "echo hello"
})
}

return komandan.komando(hosts, task)
})
.eval::<Table>();

assert!(result.is_ok());
}

#[test]
fn test_komando_use_default_user_from_env() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
env::set_var("USER", "usertest");

let result = lua
.load(chunk! {
local hosts = {
address = "localhost",
private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519",
}

local task = {
komandan.modules.cmd({
cmd = "echo hello"
})
}

return komandan.komando(hosts, task)
})
.eval::<Table>();

assert!(result.is_ok());
}

#[test]
fn test_komando_no_user_specified() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
env::remove_var("USER");

let result = lua
.load(chunk! {
local hosts = {
address = "localhost",
private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519",
}

local task = {
komandan.modules.cmd({
cmd = "echo hello"
})
}

return komandan.komando(hosts, task)
})
.eval::<Table>();

assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("No user specified for task"));
}

#[test]
fn test_komando_simple_cmd() {
let lua = Lua::new();
Expand Down

0 comments on commit 708b5b0

Please sign in to comment.