diff --git a/src/lib.rs b/src/lib.rs index c1b0aab..0a96359 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,9 +13,10 @@ use rustyline::DefaultEditor; use ssh::{ElevateMethod, Elevation, SSHAuthMethod, SSHSession}; use std::{env, path::Path}; use util::{ - dprint, filter_hosts, hostname_display, parse_hosts_json, regex_is_match, set_defaults, + dprint, filter_hosts, host_display, parse_hosts_json, regex_is_match, set_defaults, + task_display, }; -use validator::{validate_host, validate_module, validate_task}; +use validator::{validate_host, validate_task}; pub fn setup_lua_env(lua: &Lua) -> mlua::Result<()> { let args = Args::parse(); @@ -86,16 +87,10 @@ pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result { let host = lua.create_function(validate_host)?.call::
(&host)?; let task = lua.create_function(validate_task)?.call::
(&task)?; - let module = lua - .create_function(validate_module)? - .call::
(task.get::
(1).unwrap())?; + let module = task.get::
(1)?; - let host_display = hostname_display(&host); - - let task_display = match task.get::("name") { - Ok(name) => name, - Err(_) => module.get::("name")?, - }; + let host_display = host_display(&host); + let task_display = task_display(&task); let defaults = lua .globals() @@ -287,14 +282,13 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result
ssh.set_env(&key, &value); } - let module_clone = module.clone(); let results = lua .load(chunk! { print("Running task '" .. $task_display .. "' on host '" .. $host_display .."' ...") - $module_clone.ssh = $ssh - $module_clone:run() + $module.ssh = $ssh + $module:run() - local results = $module_clone.ssh:get_session_results() + local results = $module.ssh:get_session_results() komandan.dprint(results.stdout) if results.exit_code ~= 0 then print("Task '" .. $task_display .. "' on host '" .. $host_display .."' failed with exit code " .. results.exit_code .. ": " .. results.stderr) @@ -302,17 +296,14 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result
print("Task '" .. $task_display .. "' on host '" .. $host_display .."' succeeded.") end + if $module.cleanup ~= nil then + $module:cleanup() + end + return results }) .eval::
()?; - lua.load(chunk! { - if $module.cleanup ~= nil then - $module:cleanup() - end - }) - .eval::<()>()?; - let ignore_exit_code = task .get::("ignore_exit_code") .unwrap_or_else(|_| defaults.get::("ignore_exit_code").unwrap()); diff --git a/src/util.rs b/src/util.rs index 7714319..41a0edb 100644 --- a/src/util.rs +++ b/src/util.rs @@ -178,7 +178,7 @@ pub fn regex_is_match( Ok(re.is_match(&text.to_str()?)) } -pub fn hostname_display(host: &Table) -> String { +pub fn host_display(host: &Table) -> String { let address = host.get::("address").unwrap(); match host.get::("name") { @@ -187,6 +187,14 @@ pub fn hostname_display(host: &Table) -> String { } } +pub fn task_display(task: &Table) -> String { + let module = task.get::
(1).unwrap(); + match task.get::("name") { + Ok(name) => name, + Err(_) => module.get::("name").unwrap(), + } +} + // Tests #[cfg(test)] mod tests { @@ -559,11 +567,11 @@ mod tests { let host = lua.create_table().unwrap(); host.set("address", "192.168.1.1").unwrap(); host.set("name", "test").unwrap(); - assert_eq!(hostname_display(&host), "test (192.168.1.1)"); + assert_eq!(host_display(&host), "test (192.168.1.1)"); // Test without name let host = lua.create_table().unwrap(); host.set("address", "10.0.0.1").unwrap(); - assert_eq!(hostname_display(&host), "10.0.0.1"); + assert_eq!(host_display(&host), "10.0.0.1"); } } diff --git a/src/validator.rs b/src/validator.rs index f626747..08753c8 100644 --- a/src/validator.rs +++ b/src/validator.rs @@ -1,4 +1,4 @@ -use mlua::{chunk, Error::RuntimeError, Integer, Lua, Table, Value}; +use mlua::{chunk, Error::RuntimeError, ExternalResult, Integer, Lua, Table, Value}; pub fn validate_host(lua: &Lua, host: Value) -> mlua::Result
{ if !host.is_table() { @@ -33,16 +33,19 @@ fn validate_port(_: &Lua, port: Value) -> mlua::Result { Ok(port.as_integer().unwrap()) } -pub fn validate_task(_: &Lua, task: Value) -> mlua::Result
{ +pub fn validate_task(lua: &Lua, task: Value) -> mlua::Result
{ if !task.is_table() { return Err(RuntimeError("Task is not a table.".to_string())); } - if task.as_table().unwrap().get::(1)?.is_nil() { + let task = task.as_table().unwrap(); + if task.get::(1)?.is_nil() { return Err(RuntimeError("Task is invalid.".to_string())); } - Ok(task.as_table().unwrap().to_owned()) + validate_module(lua, task.get::(1)?).into_lua_err()?; + + Ok(task.to_owned()) } pub fn validate_module(lua: &Lua, module: Value) -> mlua::Result
{ @@ -206,7 +209,9 @@ mod tests { fn test_validate_task_valid() { let lua = Lua::new(); let task = lua.create_table().unwrap(); - task.set(1, "cmd").unwrap(); + let module = lua.create_table().unwrap(); + module.set("name", "cmd").unwrap(); + task.set(1, module).unwrap(); let result = super::validate_task(&lua, mlua::Value::Table(task)); assert!(result.is_ok()); diff --git a/tests/komando.rs b/tests/komando.rs index 19bb2c5..f6e32ff 100644 --- a/tests/komando.rs +++ b/tests/komando.rs @@ -1,6 +1,6 @@ use komandan::setup_lua_env; use mlua::{chunk, Integer, Lua, Table}; -use std::io::Write; +use std::{env, io::Write}; use tempfile::NamedTempFile; #[test] @@ -82,6 +82,91 @@ fn test_komando_userauth_invalid_password() { assert!(result.is_err()); } +#[test] +fn test_komando_use_default_user() { + let lua = Lua::new(); + setup_lua_env(&lua).unwrap(); + + let result = lua + .load(chunk! { + komandan.set_defaults({ + user = "usertest", + }) + + local hosts = { + address = "localhost", + private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519" + } + + local task = { + komandan.modules.cmd({ + cmd = "echo hello" + }) + } + + return komandan.komando(hosts, task) + }) + .eval::
(); + + assert!(result.is_ok()); +} + +#[test] +fn test_komando_use_default_user_from_env() { + let lua = Lua::new(); + setup_lua_env(&lua).unwrap(); + env::set_var("USER", "usertest"); + + let result = lua + .load(chunk! { + local hosts = { + address = "localhost", + private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519", + } + + local task = { + komandan.modules.cmd({ + cmd = "echo hello" + }) + } + + return komandan.komando(hosts, task) + }) + .eval::
(); + + assert!(result.is_ok()); +} + +#[test] +fn test_komando_no_user_specified() { + let lua = Lua::new(); + setup_lua_env(&lua).unwrap(); + env::remove_var("USER"); + + let result = lua + .load(chunk! { + local hosts = { + address = "localhost", + private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519", + } + + local task = { + komandan.modules.cmd({ + cmd = "echo hello" + }) + } + + return komandan.komando(hosts, task) + }) + .eval::
(); + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("No user specified for task")); +} + #[test] fn test_komando_simple_cmd() { let lua = Lua::new();