diff --git a/src/lib.rs b/src/lib.rs
index c1b0aab..0a96359 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,9 +13,10 @@ use rustyline::DefaultEditor;
use ssh::{ElevateMethod, Elevation, SSHAuthMethod, SSHSession};
use std::{env, path::Path};
use util::{
- dprint, filter_hosts, hostname_display, parse_hosts_json, regex_is_match, set_defaults,
+ dprint, filter_hosts, host_display, parse_hosts_json, regex_is_match, set_defaults,
+ task_display,
};
-use validator::{validate_host, validate_module, validate_task};
+use validator::{validate_host, validate_task};
pub fn setup_lua_env(lua: &Lua) -> mlua::Result<()> {
let args = Args::parse();
@@ -86,16 +87,10 @@ pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> {
async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result
{
let host = lua.create_function(validate_host)?.call::(&host)?;
let task = lua.create_function(validate_task)?.call::(&task)?;
- let module = lua
- .create_function(validate_module)?
- .call::(task.get::(1).unwrap())?;
+ let module = task.get::(1)?;
- let host_display = hostname_display(&host);
-
- let task_display = match task.get::("name") {
- Ok(name) => name,
- Err(_) => module.get::("name")?,
- };
+ let host_display = host_display(&host);
+ let task_display = task_display(&task);
let defaults = lua
.globals()
@@ -287,14 +282,13 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result
ssh.set_env(&key, &value);
}
- let module_clone = module.clone();
let results = lua
.load(chunk! {
print("Running task '" .. $task_display .. "' on host '" .. $host_display .."' ...")
- $module_clone.ssh = $ssh
- $module_clone:run()
+ $module.ssh = $ssh
+ $module:run()
- local results = $module_clone.ssh:get_session_results()
+ local results = $module.ssh:get_session_results()
komandan.dprint(results.stdout)
if results.exit_code ~= 0 then
print("Task '" .. $task_display .. "' on host '" .. $host_display .."' failed with exit code " .. results.exit_code .. ": " .. results.stderr)
@@ -302,17 +296,14 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result
print("Task '" .. $task_display .. "' on host '" .. $host_display .."' succeeded.")
end
+ if $module.cleanup ~= nil then
+ $module:cleanup()
+ end
+
return results
})
.eval::()?;
- lua.load(chunk! {
- if $module.cleanup ~= nil then
- $module:cleanup()
- end
- })
- .eval::<()>()?;
-
let ignore_exit_code = task
.get::("ignore_exit_code")
.unwrap_or_else(|_| defaults.get::("ignore_exit_code").unwrap());
diff --git a/src/util.rs b/src/util.rs
index 7714319..41a0edb 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -178,7 +178,7 @@ pub fn regex_is_match(
Ok(re.is_match(&text.to_str()?))
}
-pub fn hostname_display(host: &Table) -> String {
+pub fn host_display(host: &Table) -> String {
let address = host.get::("address").unwrap();
match host.get::("name") {
@@ -187,6 +187,14 @@ pub fn hostname_display(host: &Table) -> String {
}
}
+pub fn task_display(task: &Table) -> String {
+ let module = task.get::(1).unwrap();
+ match task.get::("name") {
+ Ok(name) => name,
+ Err(_) => module.get::("name").unwrap(),
+ }
+}
+
// Tests
#[cfg(test)]
mod tests {
@@ -559,11 +567,11 @@ mod tests {
let host = lua.create_table().unwrap();
host.set("address", "192.168.1.1").unwrap();
host.set("name", "test").unwrap();
- assert_eq!(hostname_display(&host), "test (192.168.1.1)");
+ assert_eq!(host_display(&host), "test (192.168.1.1)");
// Test without name
let host = lua.create_table().unwrap();
host.set("address", "10.0.0.1").unwrap();
- assert_eq!(hostname_display(&host), "10.0.0.1");
+ assert_eq!(host_display(&host), "10.0.0.1");
}
}
diff --git a/src/validator.rs b/src/validator.rs
index f626747..08753c8 100644
--- a/src/validator.rs
+++ b/src/validator.rs
@@ -1,4 +1,4 @@
-use mlua::{chunk, Error::RuntimeError, Integer, Lua, Table, Value};
+use mlua::{chunk, Error::RuntimeError, ExternalResult, Integer, Lua, Table, Value};
pub fn validate_host(lua: &Lua, host: Value) -> mlua::Result {
if !host.is_table() {
@@ -33,16 +33,19 @@ fn validate_port(_: &Lua, port: Value) -> mlua::Result {
Ok(port.as_integer().unwrap())
}
-pub fn validate_task(_: &Lua, task: Value) -> mlua::Result {
+pub fn validate_task(lua: &Lua, task: Value) -> mlua::Result {
if !task.is_table() {
return Err(RuntimeError("Task is not a table.".to_string()));
}
- if task.as_table().unwrap().get::(1)?.is_nil() {
+ let task = task.as_table().unwrap();
+ if task.get::(1)?.is_nil() {
return Err(RuntimeError("Task is invalid.".to_string()));
}
- Ok(task.as_table().unwrap().to_owned())
+ validate_module(lua, task.get::(1)?).into_lua_err()?;
+
+ Ok(task.to_owned())
}
pub fn validate_module(lua: &Lua, module: Value) -> mlua::Result {
@@ -206,7 +209,9 @@ mod tests {
fn test_validate_task_valid() {
let lua = Lua::new();
let task = lua.create_table().unwrap();
- task.set(1, "cmd").unwrap();
+ let module = lua.create_table().unwrap();
+ module.set("name", "cmd").unwrap();
+ task.set(1, module).unwrap();
let result = super::validate_task(&lua, mlua::Value::Table(task));
assert!(result.is_ok());
diff --git a/tests/komando.rs b/tests/komando.rs
index 19bb2c5..f6e32ff 100644
--- a/tests/komando.rs
+++ b/tests/komando.rs
@@ -1,6 +1,6 @@
use komandan::setup_lua_env;
use mlua::{chunk, Integer, Lua, Table};
-use std::io::Write;
+use std::{env, io::Write};
use tempfile::NamedTempFile;
#[test]
@@ -82,6 +82,91 @@ fn test_komando_userauth_invalid_password() {
assert!(result.is_err());
}
+#[test]
+fn test_komando_use_default_user() {
+ let lua = Lua::new();
+ setup_lua_env(&lua).unwrap();
+
+ let result = lua
+ .load(chunk! {
+ komandan.set_defaults({
+ user = "usertest",
+ })
+
+ local hosts = {
+ address = "localhost",
+ private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519"
+ }
+
+ local task = {
+ komandan.modules.cmd({
+ cmd = "echo hello"
+ })
+ }
+
+ return komandan.komando(hosts, task)
+ })
+ .eval::();
+
+ assert!(result.is_ok());
+}
+
+#[test]
+fn test_komando_use_default_user_from_env() {
+ let lua = Lua::new();
+ setup_lua_env(&lua).unwrap();
+ env::set_var("USER", "usertest");
+
+ let result = lua
+ .load(chunk! {
+ local hosts = {
+ address = "localhost",
+ private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519",
+ }
+
+ local task = {
+ komandan.modules.cmd({
+ cmd = "echo hello"
+ })
+ }
+
+ return komandan.komando(hosts, task)
+ })
+ .eval::();
+
+ assert!(result.is_ok());
+}
+
+#[test]
+fn test_komando_no_user_specified() {
+ let lua = Lua::new();
+ setup_lua_env(&lua).unwrap();
+ env::remove_var("USER");
+
+ let result = lua
+ .load(chunk! {
+ local hosts = {
+ address = "localhost",
+ private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519",
+ }
+
+ local task = {
+ komandan.modules.cmd({
+ cmd = "echo hello"
+ })
+ }
+
+ return komandan.komando(hosts, task)
+ })
+ .eval::();
+
+ assert!(result.is_err());
+ assert!(result
+ .unwrap_err()
+ .to_string()
+ .contains("No user specified for task"));
+}
+
#[test]
fn test_komando_simple_cmd() {
let lua = Lua::new();