diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index a4c1ba7..6511993 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -16,12 +16,14 @@ jobs: env: USER_TEST: usertest run: | - apt update && apt install -y openssh-server + apt update && apt install -y openssh-server sudo service ssh start mkdir ${HOME}/.ssh ssh-keyscan -H localhost | tee -a ${HOME}/.ssh/known_hosts ssh-keygen -t ed25519 -N "" -f ${HOME}/.ssh/id_ed25519 useradd -m -d /home/${USER_TEST} -N ${USER_TEST} + usermod -a -G sudo ${USER_TEST} + echo '%sudo ALL=(ALL) NOPASSWD:ALL' | tee /etc/sudoers.d/sudo-nopasswd su ${USER_TEST} -c 'mkdir /home/${USER_TEST}/.ssh' cat ${HOME}/.ssh/id_ed25519.pub | su ${USER_TEST} -c 'tee -a /home/${USER_TEST}/.ssh/authorized_keys' chmod 700 /home/${USER_TEST}/.ssh diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index d5a187b..07c10de 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -39,6 +39,8 @@ jobs: ssh-keyscan -H localhost | tee -a ${HOME}/.ssh/known_hosts ssh-keygen -t ed25519 -N "" -f ${HOME}/.ssh/id_ed25519 sudo useradd -m -d /home/${USER_TEST} -N ${USER_TEST} + sudo usermod -a -G sudo ${USER_TEST} + echo '%sudo ALL=(ALL) NOPASSWD:ALL' | sudo tee /etc/sudoers.d/sudo-nopasswd sudo -u ${USER_TEST} mkdir /home/${USER_TEST}/.ssh cat ${HOME}/.ssh/id_ed25519.pub | sudo -u ${USER_TEST} tee -a /home/${USER_TEST}/.ssh/authorized_keys sudo chmod 700 /home/${USER_TEST}/.ssh diff --git a/Cargo.lock b/Cargo.lock index b06dcde..39a43cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,9 +62,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "autocfg" @@ -129,9 +129,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "clap" -version = "4.5.21" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" dependencies = [ "clap_builder", "clap_derive", @@ -139,9 +139,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.21" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" dependencies = [ "anstream", "anstyle", @@ -158,14 +158,14 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] [[package]] name = "clap_lex" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "clipboard-win" @@ -347,7 +347,6 @@ name = "komandan" version = "0.1.0" dependencies = [ "anyhow", - "base64", "clap", "http-klien", "minijinja", @@ -452,9 +451,9 @@ dependencies = [ [[package]] name = "mlua" -version = "0.10.1" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ae9546e4a268c309804e8bbb7526e31cbfdedca7cd60ac1b987d0b212e0d876" +checksum = "9ea43c3ffac2d0798bd7128815212dd78c98316b299b7a902dabef13dc7b6b8d" dependencies = [ "anyhow", "bstr", @@ -471,9 +470,9 @@ dependencies = [ [[package]] name = "mlua-sys" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efa6bf1a64f06848749b7e7727417f4ec2121599e2a10ef0a8a3888b0e9a5a0d" +checksum = "63a11d485edf0f3f04a508615d36c7d50d299cf61a7ee6d3e2530651e0a31771" dependencies = [ "cc", "cfg-if", @@ -484,17 +483,17 @@ dependencies = [ [[package]] name = "mlua_derive" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cfc5faa2e0d044b3f5f0879be2920e0a711c97744c42cf1c295cb183668933e" +checksum = "870d71c172fcf491c6b5fb4c04160619a2ee3e5a42a1402269c66bcbf1dd4deb" dependencies = [ "itertools", "once_cell", - "proc-macro-error", + "proc-macro-error2", "proc-macro2", "quote", "regex", - "syn 2.0.89", + "syn", ] [[package]] @@ -556,7 +555,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] [[package]] @@ -654,27 +653,25 @@ dependencies = [ ] [[package]] -name = "proc-macro-error" -version = "1.0.4" +name = "proc-macro-error-attr2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" dependencies = [ - "proc-macro-error-attr", "proc-macro2", "quote", - "syn 1.0.109", - "version_check", ] [[package]] -name = "proc-macro-error-attr" -version = "1.0.4" +name = "proc-macro-error2" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" dependencies = [ + "proc-macro-error-attr2", "proc-macro2", "quote", - "version_check", + "syn", ] [[package]] @@ -857,9 +854,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] @@ -876,20 +873,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.134" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" dependencies = [ "itoa", "memchr", @@ -927,16 +924,6 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.89" @@ -997,12 +984,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1149,5 +1130,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] diff --git a/Cargo.toml b/Cargo.toml index f7e5348..b079b2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,11 @@ edition = "2021" vendored-openssl = ["http-klien/vendored-openssl", "ssh2/vendored-openssl"] [dependencies] -anyhow = "1.0.93" -base64 = "0.22.1" -clap = { version = "4.5.21", features = ["derive"] } +anyhow = "1.0.95" +clap = { version = "4.5.23", features = ["derive"] } http-klien = { git = "https://github.com/hahnavi/http-klien-rs", branch = "main" } minijinja = "2.5.0" -mlua = { version = "0.10.1", features = [ +mlua = { version = "0.10.2", features = [ "anyhow", "luajit", "macros", @@ -25,8 +24,8 @@ rand = "0.8.5" rayon = "1.10.0" regex = "1.11.1" rustyline = "15.0.0" -serde = { version = "1.0.216", features = ["derive"] } -serde_json = "1.0.133" +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0.134" ssh2 = { version = "0.9.4" } [dev-dependencies] diff --git a/examples/postgresql_replication_setup_ubuntu.lua b/examples/postgresql_replication_setup_ubuntu.lua index e44220d..2a2079c 100644 --- a/examples/postgresql_replication_setup_ubuntu.lua +++ b/examples/postgresql_replication_setup_ubuntu.lua @@ -33,8 +33,11 @@ local tasks_postgresql_install = { }, { name = "Add PostgreSQL repository", + env = { + YES = "yes", + }, komandan.modules.cmd({ - cmd = "YES=yes /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh", + cmd = "/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh", }), }, { @@ -102,8 +105,10 @@ end local tasks_setup_primary = { { name = "Create replication user", - komandan.modules.cmd({ - cmd = "psql -c \"DO \\$\\$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_user WHERE usename = '" .. os.getenv("REPLICATOR_USER") .. "') THEN CREATE USER " .. os.getenv("REPLICATOR_USER") .. " WITH REPLICATION ENCRYPTED PASSWORD '" .. os.getenv("REPLICATOR_PASSWORD") .. "'; END IF; END \\$\\$;\"", + komandan.modules.postgresql_user({ + name = os.getenv("REPLICATOR_USER"), + password = os.getenv("REPLICATOR_PASSWORD"), + role_attr_flags = "REPLICATION", }), as_user = "postgres", }, diff --git a/src/args.rs b/src/args.rs index 8ae032b..74de978 100644 --- a/src/args.rs +++ b/src/args.rs @@ -12,6 +12,10 @@ pub struct Args { #[arg(short = 'e')] pub chunk: Option, + /// Dry run mode + #[arg(short, long)] + pub dry_run: bool, + /// Enter interactive mode after executing 'script'. #[arg(short, long)] pub interactive: bool, diff --git a/src/defaults.rs b/src/defaults.rs index 2b01164..a92f967 100644 --- a/src/defaults.rs +++ b/src/defaults.rs @@ -1,10 +1,11 @@ +use anyhow::{Error, Result}; use mlua::UserData; use std::{ collections::HashMap, sync::{Arc, OnceLock, RwLock}, }; -static GLOBAL_STATE: OnceLock = OnceLock::new(); +static GLOBAL_DEFAULTS: OnceLock = OnceLock::new(); #[derive(Clone)] pub struct Defaults { @@ -23,16 +24,16 @@ pub struct Defaults { } impl Defaults { - pub fn new() -> Self { + pub fn new() -> Result { let env = Arc::new(RwLock::new(HashMap::new())); match env.write() { Ok(mut env) => { env.insert("DEBIAN_FRONTEND".to_string(), "noninteractive".to_string()); } - Err(_) => {} + Err(_) => return Err(Error::msg("Failed to acquire write lock".to_string())), } - Self { + Ok(Self { port: Arc::new(RwLock::new(22)), user: Arc::new(RwLock::new(None)), private_key_file: Arc::new(RwLock::new(None)), @@ -44,15 +45,17 @@ impl Defaults { as_user: Arc::new(RwLock::new(None)), known_hosts_file: Arc::new(RwLock::new(format!( "{}/.ssh/known_hosts", - std::env::var("HOME").unwrap() + std::env::var("HOME").unwrap_or("~".to_string()) ))), host_key_check: Arc::new(RwLock::new(true)), env, - } + }) } - pub fn global() -> Self { - GLOBAL_STATE.get_or_init(|| Defaults::new()).clone() + pub fn global() -> Result { + Ok(GLOBAL_DEFAULTS + .get_or_init(|| Defaults::new().unwrap()) + .clone()) } } @@ -61,11 +64,9 @@ impl UserData for Defaults { methods.add_method("get_port", |_, this, ()| -> mlua::Result { match this.port.read() { Ok(port) => Ok(*port), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -75,22 +76,18 @@ impl UserData for Defaults { *port = new_port; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }); methods.add_method("get_user", |_, this, ()| -> mlua::Result> { match this.user.read() { Ok(user) => Ok(user.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -102,11 +99,9 @@ impl UserData for Defaults { *user = new_user; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -116,11 +111,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.private_key_file.read() { Ok(private_key_file) => Ok(private_key_file.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -133,11 +126,9 @@ impl UserData for Defaults { *private_key_file = new_private_key_file; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -147,11 +138,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.private_key_pass.read() { Ok(private_key_pass) => Ok(private_key_pass.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -164,11 +153,9 @@ impl UserData for Defaults { *private_key_pass = new_private_key_pass; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -178,11 +165,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.password.read() { Ok(password) => Ok(password.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -195,11 +180,9 @@ impl UserData for Defaults { *password = new_password; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -209,11 +192,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result { match this.ignore_exit_code.read() { Ok(ignore_exit_code) => Ok(*ignore_exit_code), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -226,11 +207,9 @@ impl UserData for Defaults { *ignore_exit_code = new_ignore_exit_code; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -238,11 +217,9 @@ impl UserData for Defaults { methods.add_method("get_elevate", |_, this, ()| -> mlua::Result { match this.elevate.read() { Ok(elevate) => Ok(*elevate), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -254,11 +231,9 @@ impl UserData for Defaults { *elevate = new_elevate; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -268,11 +243,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result { match this.elevation_method.read() { Ok(elevation_method) => Ok(elevation_method.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -285,11 +258,9 @@ impl UserData for Defaults { *elevation_method = new_elevation_method; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -299,11 +270,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.as_user.read() { Ok(as_user) => Ok(as_user.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -316,11 +285,9 @@ impl UserData for Defaults { *as_user = new_as_user; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -330,11 +297,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result { match this.known_hosts_file.read() { Ok(known_hosts_file) => Ok(known_hosts_file.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -347,11 +312,9 @@ impl UserData for Defaults { *known_hosts_file = new_known_hosts_file; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -359,11 +322,9 @@ impl UserData for Defaults { methods.add_method("get_host_key_check", |_, this, ()| -> mlua::Result { match this.host_key_check.read() { Ok(host_key_check) => Ok(*host_key_check), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -375,11 +336,9 @@ impl UserData for Defaults { *host_key_check = new_host_key_check; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -433,8 +392,8 @@ mod tests { use super::*; #[test] - fn test_defaults_new() { - let defaults = Defaults::new(); + fn test_defaults_new() -> Result<()> { + let defaults = Defaults::new()?; // Test default values assert_eq!(*defaults.port.read().unwrap(), 22); @@ -442,11 +401,11 @@ mod tests { assert_eq!(*defaults.private_key_file.read().unwrap(), None); assert_eq!(*defaults.private_key_pass.read().unwrap(), None); assert_eq!(*defaults.password.read().unwrap(), None); - assert_eq!(*defaults.ignore_exit_code.read().unwrap(), false); - assert_eq!(*defaults.elevate.read().unwrap(), false); + assert!(!(*defaults.ignore_exit_code.read().unwrap())); + assert!(!(*defaults.elevate.read().unwrap())); assert_eq!(*defaults.elevation_method.read().unwrap(), "sudo"); assert_eq!(*defaults.as_user.read().unwrap(), None); - assert_eq!(*defaults.host_key_check.read().unwrap(), true); + assert!(*defaults.host_key_check.read().unwrap()); // Test default environment variables let env = defaults.env.read().unwrap(); @@ -454,24 +413,28 @@ mod tests { env.get("DEBIAN_FRONTEND"), Some(&"noninteractive".to_string()) ); + + Ok(()) } #[test] - fn test_global_singleton() { - let defaults1 = Defaults::global(); - let defaults2 = Defaults::global(); + fn test_global_singleton() -> Result<()> { + let defaults1 = Defaults::global()?; + let defaults2 = Defaults::global()?; // Modify a value using the first instance *defaults1.port.write().unwrap() = 2222; // Check if the change is reflected in the second instance assert_eq!(*defaults2.port.read().unwrap(), 2222); + + Ok(()) } #[test] - fn test_lua_interface() -> mlua::Result<()> { + fn test_lua_interface() -> Result<()> { let lua = mlua::Lua::new(); - let defaults = Defaults::new(); + let defaults = Defaults::new()?; // Register the defaults instance with Lua lua.globals().set("defaults", defaults.clone())?; diff --git a/src/komando.rs b/src/komando.rs new file mode 100644 index 0000000..b80761c --- /dev/null +++ b/src/komando.rs @@ -0,0 +1,567 @@ +use std::collections::HashMap; +use std::env; + +use clap::Parser; +use mlua::{chunk, IntoLua, LuaSerdeExt}; +use mlua::{Error::RuntimeError, FromLua, Integer, Lua, Table, Value}; +use rayon::prelude::*; + +use crate::args::Args; +use crate::create_lua; +use crate::defaults::Defaults; +use crate::models::{Host, KomandoResult, Task}; +use crate::report::{insert_record, TaskStatus}; +use crate::ssh::{Elevation, ElevationMethod, SSHAuthMethod, SSHSession}; +use crate::util::{host_display, task_display}; +use crate::validator::{validate_host, validate_task}; + +pub fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result { + let host = lua.create_function(validate_host)?.call::
(&host)?; + let task = lua.create_function(validate_task)?.call::
(&task)?; + let module = task.get::
(1)?; + + let host_display = host_display(&host); + let task_display = task_display(&task); + + let defaults = Defaults::global()?; + + let (user, ssh_auth_method) = get_auth_config(&host, &task)?; + let elevation = get_elevation_config(&host, &task)?; + + let default_port = match defaults.port.read() { + Ok(port) => port, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let port = host.get::("port").unwrap_or(*default_port as i64) as u16; + + let mut ssh = create_ssh_session(&host)?; + ssh.elevation = elevation; + + ssh.connect( + host.get::("address")?.as_str(), + port, + &user, + ssh_auth_method, + )?; + + setup_environment(&mut ssh, &host, &task)?; + + let result = execute_task(lua, &module, ssh, &task_display, &host_display)?; + + let default_ignore_exit_code = match defaults.ignore_exit_code.read() { + Ok(ignore_exit_code) => ignore_exit_code, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let ignore_exit_code = task + .get::("ignore_exit_code") + .unwrap_or(*default_ignore_exit_code); + + let exit_code = result.get::("exit_code")?; + + if exit_code != 0 && !ignore_exit_code { + return Err(RuntimeError("Failed to run task.".to_string())); + } + + let task_status = if exit_code != 0 { + TaskStatus::Failed + } else { + match result.get::("changed")? { + true => TaskStatus::Changed, + false => TaskStatus::OK, + } + }; + + insert_record(task_display, host_display, task_status); + + Ok(result) +} + +pub fn komando_parallel_tasks(lua: &Lua, (host, tasks): (Value, Value)) -> mlua::Result
{ + let host = Host::from_lua(host, lua)?; + let mut tasks_hm = HashMap::::new(); + for pair in tasks.as_table().unwrap().pairs::() { + let (key, value): (u32, Value) = pair?; + let task = Task::from_lua(value, lua)?; + tasks_hm.insert(key, task); + } + + let results: HashMap = tasks_hm + .par_iter() + .map(|(i, task)| { + let lua = create_lua().unwrap(); + let host = host.clone().into_lua(&lua).unwrap(); + let task = task.clone().into_lua(&lua).unwrap(); + let result = komando(&lua, (host, task)).unwrap(); + + ( + *i, + lua.from_value::(Value::Table(result)) + .unwrap(), + ) + }) + .collect::>(); + + let results_table = lua.create_table()?; + results.iter().for_each(|(i, result)| { + results_table + .set(*i, lua.to_value(result).unwrap()) + .unwrap(); + }); + + Ok(results_table) +} + +pub fn komando_parallel_hosts(lua: &Lua, (hosts, task): (Value, Value)) -> mlua::Result
{ + let task = Task::from_lua(task, lua)?; + let mut hosts_hm = HashMap::::new(); + for pair in hosts.as_table().unwrap().pairs::() { + let (key, value): (u32, Value) = pair?; + let host = Host::from_lua(value, lua)?; + hosts_hm.insert(key, host); + } + + let results: HashMap = hosts_hm + .par_iter() + .map(|(i, host)| { + let lua = create_lua().unwrap(); + let host = host.clone().into_lua(&lua).unwrap(); + let task = task.clone().into_lua(&lua).unwrap(); + let result = komando(&lua, (host, task)).unwrap(); + + ( + *i, + lua.from_value::(Value::Table(result)) + .unwrap(), + ) + }) + .collect::>(); + + let results_table = lua.create_table()?; + results.iter().for_each(|(i, result)| { + results_table + .set(*i, lua.to_value(result).unwrap()) + .unwrap(); + }); + + Ok(results_table) +} + +fn get_user(host: &Table, task: &Table) -> mlua::Result { + let defaults = Defaults::global()?; + let default_user = match defaults.user.read() { + Ok(user) => user, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + let user = match host.get::("user") { + Ok(user) => user, + Err(_) => match *default_user { + Some(ref user) => user.clone(), + None => match env::var("USER") { + Ok(user) => user, + Err(_) => { + return Err(RuntimeError(format!( + "No user specified for task '{}'.", + task_display(task) + ))) + } + }, + }, + }; + + Ok(user) +} + +fn get_auth_config(host: &Table, task: &Table) -> mlua::Result<(String, SSHAuthMethod)> { + let host_display = host_display(host); + let task_display = task_display(task); + + let user = get_user(host, task)?; + + let defaults = Defaults::global()?; + + let default_private_key_file = match defaults.private_key_file.read() { + Ok(private_key_file) => private_key_file, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let default_private_key_pass = match defaults.private_key_pass.read() { + Ok(private_key_pass) => private_key_pass, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let default_password = match defaults.password.read() { + Ok(password) => password, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let ssh_auth_method = match host.get::("private_key_file") { + Ok(private_key_file) => SSHAuthMethod::PublicKey { + private_key: private_key_file, + passphrase: match host.get::("private_key_pass") { + Ok(passphrase) => Some(passphrase), + Err(_) => (*default_private_key_pass).clone(), + }, + }, + Err(_) => match *default_private_key_file { + Some(ref private_key_file) => SSHAuthMethod::PublicKey { + private_key: private_key_file.clone(), + passphrase: match host.get::("private_key_pass") { + Ok(passphrase) => Some(passphrase), + Err(_) => (*default_private_key_pass).clone(), + }, + }, + None => match host.get::("password") { + Ok(password) => SSHAuthMethod::Password(password), + Err(_) => match *default_password { + Some(ref password) => SSHAuthMethod::Password(password.clone()), + None => { + return Err(RuntimeError(format!( + "No authentication method specified for task '{}' on host '{}'.", + task_display, host_display + ))) + } + }, + }, + }, + }; + + Ok((user, ssh_auth_method)) +} + +fn get_elevation_config(host: &Table, task: &Table) -> mlua::Result { + let defaults = Defaults::global()?; + + let default_elevate = match defaults.elevate.read() { + Ok(elevate) => elevate, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let task_elevate = task.get::("elevate")?; + let host_elevate = host.get::("elevate")?; + + let elevate = if !task_elevate.is_nil() { + task_elevate.as_boolean().unwrap() + } else if !host_elevate.is_nil() { + host_elevate.as_boolean().unwrap() + } else { + *default_elevate + }; + + if !elevate { + return Ok(Elevation { + method: ElevationMethod::None, + as_user: None, + }); + } + + let default_elevation_method = match defaults.elevation_method.read() { + Ok(elevation_method) => elevation_method, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let elevation_method_str = task.get::("elevation_method").unwrap_or( + host.get::("elevation_method") + .unwrap_or(default_elevation_method.clone()), + ); + + let elevation_method = match elevation_method_str.as_str() { + "none" => Ok(ElevationMethod::None), + "sudo" => Ok(ElevationMethod::Sudo), + "su" => Ok(ElevationMethod::Su), + _ => Err(RuntimeError(format!( + "Unsupported elevation method: {}", + elevation_method_str + ))), + }; + + let default_as_user = match defaults.as_user.read() { + Ok(as_user) => as_user, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let as_user = task.get::>("as_user").unwrap_or( + host.get::>("as_user") + .unwrap_or(default_as_user.clone()), + ); + + Ok(Elevation { + method: elevation_method?, + as_user, + }) +} + +fn create_ssh_session(host: &Table) -> mlua::Result { + let defaults = Defaults::global()?; + let mut ssh = SSHSession::new()?; + + let default_host_key_check = match defaults.host_key_check.read() { + Ok(host_key_check) => host_key_check, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let host_key_check = match host.get::("host_key_check") { + Ok(host_key_check) => match host_key_check { + Value::Nil => *default_host_key_check, + Value::Boolean(false) => false, + _ => true, + }, + Err(_) => true, + }; + + let default_known_hosts_file = match defaults.known_hosts_file.read() { + Ok(known_hosts_file) => known_hosts_file, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + if host_key_check { + ssh.known_hosts_file = match host.get::("known_hosts_file") { + Ok(known_hosts_file) => Some(known_hosts_file), + Err(_) => Some(default_known_hosts_file.clone()), + }; + } + + Ok(ssh) +} + +fn setup_environment(ssh: &mut SSHSession, host: &Table, task: &Table) -> mlua::Result<()> { + let defaults = Defaults::global()?; + + let default_env = match defaults.env.read() { + Ok(env) => env, + Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), + }; + + let env_host = host.get::>("env")?; + let env_task = task.get::>("env")?; + + for (key, value) in default_env.clone() { + ssh.set_env(&key, &value); + } + + if env_host.is_some() { + for pair in env_host.unwrap().pairs() { + let (key, value): (String, String) = pair?; + ssh.set_env(&key, &value); + } + } + + if env_task.is_some() { + for pair in env_task.unwrap().pairs() { + let (key, value): (String, String) = pair?; + ssh.set_env(&key, &value); + } + } + + Ok(()) +} + +fn execute_task( + lua: &Lua, + module: &Table, + ssh: SSHSession, + task_display: &str, + host_display: &str, +) -> mlua::Result
{ + let dry_run = Args::parse().dry_run; + + lua.load(chunk! { + print(">> Running task '" .. $task_display .. "' on host '" .. $host_display .."' ...") + $module.ssh = $ssh + + if $dry_run then + if $module.dry_run ~= nil then + $module:dry_run() + else + print("[[ Task '" .. $task_display .. "' on host '" .. $host_display .."' does not support dry-run. Assuming 'changed' is true. ]]") + $module.ssh:set_changed(true) + end + else + $module:run() + end + + local result = $module.ssh:get_session_result() + komandan.dprint(result.stdout) + if result.exit_code ~= 0 then + print(">> Task '" .. $task_display .. "' on host '" .. $host_display .."' failed with exit code " .. result.exit_code .. ": " .. result.stderr) + else + local state = "[OK]" + if result.changed then + state = "[Changed]" + end + print(">> Task '" .. $task_display .. "' on host '" .. $host_display .."' succeeded. " .. state) + end + + if $module.cleanup ~= nil then + $module:cleanup() + end + + return result + }) + .set_name("execute_task") + .eval::
() +} + +// Tests +#[cfg(test)] +mod tests { + use anyhow::Result; + + use super::*; + + #[test] + fn test_get_auth_config() -> Result<()> { + let lua = create_lua()?; + let host = lua.create_table()?; + + // Test with user in host + host.set("address", "localhost")?; + host.set("user", "testuser")?; + host.set("private_key_file", "/path/to/key")?; + + let module_params = lua.create_table()?; + module_params.set("cmd", "echo test")?; + let module = lua + .load(chunk! { + return komandan.modules.cmd($module_params) + }) + .eval::
()?; + let task = lua.create_table()?; + task.set(1, module)?; + + let (user, auth) = get_auth_config(&host, &task)?; + assert_eq!(user, "testuser"); + match auth { + SSHAuthMethod::PublicKey { + private_key, + passphrase, + } => { + assert_eq!(private_key, "/path/to/key"); + assert!(passphrase.is_none()); + } + _ => panic!("Expected PublicKey authentication"), + } + + // Test with password auth + host.set("private_key_file", Value::Nil)?; + host.set("password", "testpass")?; + let (_, auth) = get_auth_config(&host, &task)?; + match auth { + SSHAuthMethod::Password(pass) => assert_eq!(pass, "testpass"), + _ => panic!("Expected Password authentication"), + } + + // Test with no authentication method + host.set("password", Value::Nil)?; + let result = get_auth_config(&host, &task); + assert!(result.is_err()); + + Ok(()) + } + + #[test] + fn test_get_elevation_config() -> Result<()> { + let lua = create_lua()?; + let host = lua.create_table()?; + let task = lua.create_table()?; + + // Test with no elevation + let elevation = get_elevation_config(&host, &task)?; + assert!(matches!( + elevation, + Elevation { + method: ElevationMethod::None, + as_user: None + } + )); + + // Test with elevation from task + task.set("elevate", true)?; + let elevation = get_elevation_config(&host, &task)?; + assert!(matches!( + elevation, + Elevation { + method: ElevationMethod::Sudo, + as_user: None + } + )); + + // Test with custom elevation method + task.set("elevation_method", "su")?; + let elevation = get_elevation_config(&host, &task)?; + assert!(matches!( + elevation, + Elevation { + method: ElevationMethod::Su, + as_user: None + } + )); + + // Test invalid elevation method + task.set("elevation_method", "invalid")?; + assert!(get_elevation_config(&host, &task).is_err()); + + Ok(()) + } + + #[test] + fn test_setup_ssh_session() -> Result<()> { + let lua = create_lua()?; + let host = lua.create_table()?; + host.set("address", "localhost")?; + + // Test with default settings + let ssh = create_ssh_session(&host)?; + assert!(ssh.known_hosts_file.is_some()); + + // Test with host key check disabled + host.set("host_key_check", false)?; + let ssh = create_ssh_session(&host)?; + assert!(ssh.known_hosts_file.is_none()); + + // Test with custom known_hosts file + host.set("known_hosts_file", "/path/to/known_hosts")?; + host.set("host_key_check", true)?; + let ssh = create_ssh_session(&host)?; + assert_eq!(ssh.known_hosts_file.unwrap(), "/path/to/known_hosts"); + + // Test with known_hosts from defaults + host.set("known_hosts_file", Value::Nil)?; + lua.load(chunk! { + komandan.defaults:set_known_hosts_file("/default/known_hosts") + }) + .exec()?; + let ssh = create_ssh_session(&host)?; + assert_eq!(ssh.known_hosts_file.unwrap(), "/default/known_hosts"); + + Ok(()) + } + + #[test] + fn test_setup_environment() -> Result<()> { + let lua = create_lua()?; + let mut ssh = SSHSession::new()?; + let defaults = lua.create_table()?; + let host = lua.create_table()?; + let task = lua.create_table()?; + + // Test with environment variables at all levels + let env_defaults = lua.create_table()?; + env_defaults.set("DEFAULT_VAR", "default_value")?; + defaults.set("env", env_defaults)?; + + let env_host = lua.create_table()?; + env_host.set("HOST_VAR", "host_value")?; + env_host.set("DEFAULT_VAR", "overridden_value")?; // Override default + host.set("env", env_host)?; + + let env_task = lua.create_table()?; + env_task.set("TASK_VAR", "task_value")?; + task.set("env", env_task)?; + + setup_environment(&mut ssh, &host, &task)?; + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 33f18c4..01e2acd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,29 +1,24 @@ mod args; mod defaults; +mod komando; +mod models; mod modules; -pub mod ssh; +mod report; +mod ssh; mod util; mod validator; +use anyhow::Result; use args::Args; use clap::Parser; use defaults::Defaults; -use mlua::{ - chunk, - Error::{self, RuntimeError}, - FromLua, Integer, IntoLua, Lua, LuaSerdeExt, MultiValue, Table, UserData, Value, -}; -use modules::{base_module, collect_modules}; -use rayon::prelude::*; +use komando::{komando, komando_parallel_hosts, komando_parallel_tasks}; +use mlua::{chunk, Lua, MultiValue}; +use modules::{base_module, collect_core_modules}; +use report::generate_report; use rustyline::DefaultEditor; -use serde::{Deserialize, Serialize}; -use ssh::{Elevation, ElevationMethod, SSHAuthMethod, SSHSession}; -use std::{collections::HashMap, env, fs, path::Path}; -use util::{ - dprint, filter_hosts, host_display, parse_hosts_json_file, parse_hosts_json_url, - regex_is_match, task_display, -}; -use validator::{validate_host, validate_task}; +use std::{env, fs, path::Path}; +use util::{dprint, filter_hosts, parse_hosts_json_file, parse_hosts_json_url, regex_is_match}; pub fn create_lua() -> mlua::Result { let lua = Lua::new(); @@ -62,10 +57,10 @@ pub fn create_lua() -> mlua::Result { pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { let komandan = lua.create_table()?; - let defaults = Defaults::global(); + let defaults = Defaults::global()?; komandan.set("defaults", defaults)?; - let base_module = base_module(&lua); + let base_module = base_module(lua)?; komandan.set("KomandanModule", base_module)?; komandan.set("komando", lua.create_function(komando)?)?; @@ -92,582 +87,14 @@ pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { komandan.set("dprint", lua.create_function(dprint)?)?; // Add core modules - komandan.set("modules", collect_modules(lua))?; + komandan.set("modules", collect_core_modules(lua)?)?; lua.globals().set("komandan", &komandan)?; Ok(()) } -fn get_user(host: &Table, task: &Table) -> mlua::Result { - let defaults = Defaults::global(); - let default_user = match defaults.user.read() { - Ok(user) => user, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - let user = match host.get::("user") { - Ok(user) => user, - Err(_) => match *default_user { - Some(ref user) => user.clone(), - None => match env::var("USER") { - Ok(user) => user, - Err(_) => { - return Err(RuntimeError(format!( - "No user specified for task '{}'.", - task_display(task) - ))) - } - }, - }, - }; - - Ok(user) -} - -fn get_auth_config(host: &Table, task: &Table) -> mlua::Result<(String, SSHAuthMethod)> { - let host_display = host_display(host); - let task_display = task_display(task); - - let user = get_user(host, task)?; - - let defaults = Defaults::global(); - - let default_private_key_file = match defaults.private_key_file.read() { - Ok(private_key_file) => private_key_file, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let default_private_key_pass = match defaults.private_key_pass.read() { - Ok(private_key_pass) => private_key_pass, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let default_password = match defaults.password.read() { - Ok(password) => password, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let ssh_auth_method = match host.get::("private_key_file") { - Ok(private_key_file) => SSHAuthMethod::PublicKey { - private_key: private_key_file, - passphrase: match host.get::("private_key_pass") { - Ok(passphrase) => Some(passphrase), - Err(_) => match *default_private_key_pass { - Some(ref private_key_pass) => Some(private_key_pass.clone()), - None => None, - }, - }, - }, - Err(_) => match *default_private_key_file { - Some(ref private_key_file) => SSHAuthMethod::PublicKey { - private_key: private_key_file.clone(), - passphrase: match host.get::("private_key_pass") { - Ok(passphrase) => Some(passphrase), - Err(_) => match *default_private_key_pass { - Some(ref passphrase) => Some(passphrase.clone()), - None => None, - }, - }, - }, - None => match host.get::("password") { - Ok(password) => SSHAuthMethod::Password(password), - Err(_) => match *default_password { - Some(ref password) => SSHAuthMethod::Password(password.clone()), - None => { - return Err(RuntimeError(format!( - "No authentication method specified for task '{}' on host '{}'.", - task_display, host_display - ))) - } - }, - }, - }, - }; - - Ok((user, ssh_auth_method)) -} - -fn get_elevation_config(host: &Table, task: &Table) -> mlua::Result { - let defaults = Defaults::global(); - - let default_elevate = match defaults.elevate.read() { - Ok(elevate) => elevate, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let task_elevate = task.get::("elevate")?; - let host_elevate = host.get::("elevate")?; - - let elevate = if !task_elevate.is_nil() { - task_elevate.as_boolean().unwrap() - } else if !host_elevate.is_nil() { - host_elevate.as_boolean().unwrap() - } else { - *default_elevate - }; - - if !elevate { - return Ok(Elevation { - method: ElevationMethod::None, - as_user: None, - }); - } - - let default_elevation_method = match defaults.elevation_method.read() { - Ok(elevation_method) => elevation_method, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let elevation_method_str = task.get::("elevation_method").unwrap_or( - host.get::("elevation_method") - .unwrap_or(default_elevation_method.clone()), - ); - - let elevation_method = match elevation_method_str.as_str() { - "none" => Ok(ElevationMethod::None), - "sudo" => Ok(ElevationMethod::Sudo), - "su" => Ok(ElevationMethod::Su), - _ => Err(RuntimeError(format!( - "Unsupported elevation method: {}", - elevation_method_str - ))), - }; - - let default_as_user = match defaults.as_user.read() { - Ok(as_user) => as_user, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let as_user = task.get::>("as_user").unwrap_or( - host.get::>("as_user") - .unwrap_or(default_as_user.clone()), - ); - - Ok(Elevation { - method: elevation_method?, - as_user, - }) -} - -fn setup_ssh_session(host: &Table) -> mlua::Result { - let defaults = Defaults::global(); - let mut ssh = SSHSession::new()?; - - let default_host_key_check = match defaults.host_key_check.read() { - Ok(host_key_check) => host_key_check, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let host_key_check = match host.get::("host_key_check") { - Ok(host_key_check) => match host_key_check { - Value::Nil => default_host_key_check.clone(), - Value::Boolean(false) => false, - _ => true, - }, - Err(_) => true, - }; - - let default_known_hosts_file = match defaults.known_hosts_file.read() { - Ok(known_hosts_file) => known_hosts_file, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - if host_key_check { - ssh.known_hosts_file = match host.get::("known_hosts_file") { - Ok(known_hosts_file) => Some(known_hosts_file), - Err(_) => Some(default_known_hosts_file.clone()), - }; - } - - Ok(ssh) -} - -fn setup_environment(ssh: &mut SSHSession, host: &Table, task: &Table) -> mlua::Result<()> { - let defaults = Defaults::global(); - - let default_env = match defaults.env.read() { - Ok(env) => env, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let env_host = host.get::>("env")?; - let env_task = task.get::>("env")?; - - for (key, value) in default_env.clone() { - ssh.set_env(&key, &value); - } - - if !env_host.is_none() { - for pair in env_host.unwrap().pairs() { - let (key, value): (String, String) = pair?; - ssh.set_env(&key, &value); - } - } - - if !env_task.is_none() { - for pair in env_task.unwrap().pairs() { - let (key, value): (String, String) = pair?; - ssh.set_env(&key, &value); - } - } - - Ok(()) -} - -fn execute_task( - lua: &Lua, - module: &Table, - ssh: SSHSession, - task_display: &str, - host_display: &str, -) -> mlua::Result
{ - lua.load(chunk! { - print(">> Running task '" .. $task_display .. "' on host '" .. $host_display .."' ...") - $module.ssh = $ssh - $module:run() - - local results = $module.ssh:get_session_results() - komandan.dprint(results.stdout) - if results.exit_code ~= 0 then - print(">> Task '" .. $task_display .. "' on host '" .. $host_display .."' failed with exit code " .. results.exit_code .. ": " .. results.stderr) - else - print(">> Task '" .. $task_display .. "' on host '" .. $host_display .."' succeeded.") - end - - if $module.cleanup ~= nil then - $module:cleanup() - end - - return results - }) - .set_name("execute_task") - .eval::
() -} - -fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result
{ - let host = lua.create_function(validate_host)?.call::
(&host)?; - let task = lua.create_function(validate_task)?.call::
(&task)?; - let module = task.get::
(1)?; - - let host_display = host_display(&host); - let task_display = task_display(&task); - - let defaults = Defaults::global(); - - let (user, ssh_auth_method) = get_auth_config(&host, &task)?; - let elevation = get_elevation_config(&host, &task)?; - - let default_port = match defaults.port.read() { - Ok(port) => port, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let port = host - .get::("port") - .unwrap_or(default_port.clone() as i64) as u16; - - let mut ssh = setup_ssh_session(&host)?; - ssh.elevation = elevation; - - ssh.connect( - host.get::("address")?.as_str(), - port, - &user, - ssh_auth_method, - )?; - - setup_environment(&mut ssh, &host, &task)?; - - let results = execute_task(lua, &module, ssh, &task_display, &host_display)?; - - let default_ignore_exit_code = match defaults.ignore_exit_code.read() { - Ok(ignore_exit_code) => ignore_exit_code, - Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), - }; - - let ignore_exit_code = task - .get::("ignore_exit_code") - .unwrap_or(default_ignore_exit_code.clone()); - - if results.get::("exit_code").unwrap() != 0 && !ignore_exit_code { - return Err(RuntimeError("Failed to run task.".to_string())); - } - - Ok(results) -} - -#[derive(Clone, Debug)] -struct Host { - name: Option, - address: String, - port: Option, - user: Option, - private_key_file: Option, - private_key_pass: Option, - password: Option, - elevate: Option, - elevation_method: Option, - as_user: Option, - env: Option>, -} - -impl FromLua for Host { - fn from_lua(lua_value: Value, _: &Lua) -> mlua::Result { - let table = lua_value.as_table().unwrap(); - Ok(Host { - name: table.get("name")?, - address: table.get("address")?, - port: table.get("port")?, - user: table.get("user")?, - private_key_file: table.get("private_key_file")?, - private_key_pass: table.get("private_key_pass")?, - password: table.get("password")?, - elevate: table.get("elevate")?, - elevation_method: match table.get::("elevation_method") { - Ok(elevation_method) => match elevation_method.as_str() { - "none" => Some(ElevationMethod::None), - "sudo" => Some(ElevationMethod::Sudo), - "su" => Some(ElevationMethod::Su), - _ => None, - }, - Err(_) => None, - }, - as_user: table.get("as_user")?, - env: table.get("env")?, - }) - } -} - -impl IntoLua for Host { - fn into_lua(self, lua: &Lua) -> mlua::Result { - let table = lua.create_table()?; - if self.name.is_some() { - table.set("name", self.name.unwrap())?; - } - table.set("address", self.address)?; - if self.port.is_some() { - table.set("port", self.port.unwrap())?; - } - if self.user.is_some() { - table.set("user", self.user.unwrap())?; - } - if self.private_key_file.is_some() { - table.set("private_key_file", self.private_key_file.unwrap())?; - } - if self.private_key_pass.is_some() { - table.set("private_key_pass", self.private_key_pass.unwrap())?; - } - if self.password.is_some() { - table.set("password", self.password.unwrap())?; - } - if self.elevate.is_some() { - table.set("elevate", self.elevate.unwrap())?; - } - if self.elevation_method.is_some() { - match self.elevation_method.unwrap() { - ElevationMethod::None => table.set("elevation_method", "none")?, - ElevationMethod::Sudo => table.set("elevation_method", "sudo")?, - ElevationMethod::Su => table.set("elevation_method", "su")?, - } - } - if self.as_user.is_some() { - table.set("as_user", self.as_user.unwrap())?; - } - if self.env.is_some() { - table.set("env", self.env.unwrap())?; - } - Ok(Value::Table(table)) - } -} - -#[derive(Clone, Debug)] -struct Task { - name: Option, - module: Module, - ignore_exit_code: Option, - elevate: Option, - elevation_method: Option, - as_user: Option, - env: Option>, -} - -impl FromLua for Task { - fn from_lua(lua_value: Value, lua: &Lua) -> mlua::Result { - let table = lua_value.as_table().unwrap(); - Ok(Task { - name: table.get("name")?, - module: Module::from_lua(table.get(1)?, lua)?, - ignore_exit_code: table.get("ignore_exit_code")?, - elevate: table.get("elevate")?, - elevation_method: match table.get::("elevation_method") { - Ok(elevation_method) => match elevation_method.as_str() { - "none" => Some(ElevationMethod::None), - "sudo" => Some(ElevationMethod::Sudo), - "su" => Some(ElevationMethod::Su), - _ => None, - }, - Err(_) => None, - }, - as_user: table.get("as_user")?, - env: table.get("env")?, - }) - } -} - -impl IntoLua for Task { - fn into_lua(self, lua: &Lua) -> mlua::Result { - let table = lua.create_table()?; - if self.name.is_some() { - table.set("name", self.name.unwrap())?; - } - table.set(1, self.module.into_lua(lua)?)?; - if self.ignore_exit_code.is_some() { - table.set("ignore_exit_code", self.ignore_exit_code.unwrap())?; - } - if self.elevate.is_some() { - table.set("elevate", self.elevate.unwrap())?; - } - if self.elevation_method.is_some() { - match self.elevation_method.unwrap() { - ElevationMethod::None => table.set("elevation_method", "none")?, - ElevationMethod::Sudo => table.set("elevation_method", "sudo")?, - ElevationMethod::Su => table.set("elevation_method", "su")?, - } - } - if self.as_user.is_some() { - table.set("as_user", self.as_user.unwrap())?; - } - if self.env.is_some() { - table.set("env", self.env.unwrap())?; - } - Ok(Value::Table(table)) - } -} - -#[derive(Clone, Debug)] -struct Module { - functions: HashMap>, - others: HashMap, -} - -impl FromLua for Module { - fn from_lua(value: Value, _: &Lua) -> mlua::Result { - let table = value.as_table().unwrap(); - let mut functions: HashMap> = HashMap::new(); - let mut others: HashMap = HashMap::new(); - for pair in table.pairs::() { - let (key, value) = pair.unwrap(); - if value.is_function() { - functions.insert(key.to_string()?, value.as_function().unwrap().dump(true)); - } else { - others.insert( - key.to_string()?, - serde_json::to_string(&value).map_err(Error::external)?, - ); - } - } - Ok(Module { functions, others }) - } -} - -impl IntoLua for Module { - fn into_lua(self, lua: &Lua) -> mlua::Result { - let table = lua.create_table()?; - self.functions.iter().for_each(|(key, value)| { - table - .set(key.as_str(), lua.load(value).into_function().unwrap()) - .unwrap(); - }); - self.others.iter().for_each(|(key, value)| { - let json: serde_json::Value = serde_json::from_str(value).unwrap(); - table - .set(key.as_str(), lua.to_value(&json).unwrap()) - .unwrap(); - }); - Ok(Value::Table(table)) - } -} - -#[derive(Serialize, Deserialize)] -struct KomandoResult { - stdout: String, - stderr: String, - exit_code: i32, -} - -impl UserData for KomandoResult {} - -fn komando_parallel_tasks(lua: &Lua, (host, tasks): (Value, Value)) -> mlua::Result
{ - let host = Host::from_lua(host, lua)?; - let mut tasks_hm = HashMap::::new(); - for pair in tasks.as_table().unwrap().pairs::() { - let (key, value): (u32, Value) = pair?; - let task = Task::from_lua(value, lua)?; - tasks_hm.insert(key, task); - } - - let results: HashMap = tasks_hm - .par_iter() - .map(|(i, task)| { - let lua = create_lua().unwrap(); - let host = host.clone().into_lua(&lua).unwrap(); - let task = task.clone().into_lua(&lua).unwrap(); - let result = komando(&lua, (host, task)).unwrap(); - - return ( - *i, - lua.from_value::(Value::Table(result)) - .unwrap(), - ); - }) - .collect::>(); - - let results_table = lua.create_table()?; - results.iter().for_each(|(i, result)| { - results_table - .set(*i, lua.to_value(result).unwrap()) - .unwrap(); - }); - - Ok(results_table) -} - -fn komando_parallel_hosts(lua: &Lua, (hosts, task): (Value, Value)) -> mlua::Result
{ - let task = Task::from_lua(task, lua)?; - let mut hosts_hm = HashMap::::new(); - for pair in hosts.as_table().unwrap().pairs::() { - let (key, value): (u32, Value) = pair?; - let host = Host::from_lua(value, lua)?; - hosts_hm.insert(key, host); - } - - let results: HashMap = hosts_hm - .par_iter() - .map(|(i, host)| { - let lua = create_lua().unwrap(); - let host = host.clone().into_lua(&lua).unwrap(); - let task = task.clone().into_lua(&lua).unwrap(); - let result = komando(&lua, (host, task)).unwrap(); - - return ( - *i, - lua.from_value::(Value::Table(result)) - .unwrap(), - ); - }) - .collect::>(); - - let results_table = lua.create_table()?; - results.iter().for_each(|(i, result)| { - results_table - .set(*i, lua.to_value(result).unwrap()) - .unwrap(); - }); - - Ok(results_table) -} - -pub fn run_main_file(lua: &Lua, main_file: &String) -> anyhow::Result<()> { +pub fn run_main_file(lua: &Lua, main_file: &String) -> Result<()> { let script = match fs::read_to_string(main_file) { Ok(script) => script, Err(e) => { @@ -681,6 +108,8 @@ pub fn run_main_file(lua: &Lua, main_file: &String) -> anyhow::Result<()> { lua.load(&script).set_name(main_file).exec()?; + generate_report(); + Ok(()) } @@ -715,7 +144,7 @@ pub fn repl(lua: &Lua) { incomplete_input: true, .. }) => { - line.push_str("\n"); + line.push('\n'); prompt = ">> "; } Err(e) => { @@ -736,6 +165,8 @@ pub fn print_version() { // Tests #[cfg(test)] mod tests { + use mlua::Table; + use super::*; #[test] @@ -746,6 +177,7 @@ mod tests { args, Args { chunk: None, + dry_run: false, interactive: false, verbose: true, version: false, @@ -755,190 +187,42 @@ mod tests { } #[test] - fn test_setup_komandan_table() { - let lua = create_lua().unwrap(); + fn test_setup_komandan_table() -> Result<()> { + let lua = create_lua()?; // Assert that the komandan table is set up correctly - let komandan_table = lua.globals().get::
("komandan").unwrap(); - assert!(komandan_table.contains_key("defaults").unwrap()); - assert!(komandan_table.contains_key("KomandanModule").unwrap()); - assert!(komandan_table.contains_key("komando").unwrap()); - assert!(komandan_table.contains_key("regex_is_match").unwrap()); - assert!(komandan_table.contains_key("filter_hosts").unwrap()); - assert!(komandan_table - .contains_key("parse_hosts_json_file") - .unwrap()); - assert!(komandan_table.contains_key("parse_hosts_json_url").unwrap()); - assert!(komandan_table.contains_key("dprint").unwrap()); - - let modules_table = komandan_table.get::
("modules").unwrap(); - assert!(modules_table.contains_key("apt").unwrap()); - assert!(modules_table.contains_key("cmd").unwrap()); - assert!(modules_table.contains_key("lineinfile").unwrap()); - assert!(modules_table.contains_key("script").unwrap()); - assert!(modules_table.contains_key("systemd_service").unwrap()); - assert!(modules_table.contains_key("template").unwrap()); - assert!(modules_table.contains_key("upload").unwrap()); - assert!(modules_table.contains_key("download").unwrap()); + let komandan_table = lua.globals().get::
("komandan")?; + assert!(komandan_table.contains_key("defaults")?); + assert!(komandan_table.contains_key("KomandanModule")?); + assert!(komandan_table.contains_key("komando")?); + assert!(komandan_table.contains_key("regex_is_match")?); + assert!(komandan_table.contains_key("filter_hosts")?); + assert!(komandan_table.contains_key("parse_hosts_json_file")?); + assert!(komandan_table.contains_key("parse_hosts_json_url")?); + assert!(komandan_table.contains_key("dprint")?); + + let modules_table = komandan_table.get::
("modules")?; + assert!(modules_table.contains_key("apt")?); + assert!(modules_table.contains_key("cmd")?); + assert!(modules_table.contains_key("lineinfile")?); + assert!(modules_table.contains_key("script")?); + assert!(modules_table.contains_key("systemd_service")?); + assert!(modules_table.contains_key("template")?); + assert!(modules_table.contains_key("upload")?); + assert!(modules_table.contains_key("download")?); + + Ok(()) } #[test] - fn test_run_main_file() { - let lua = create_lua().unwrap(); + fn test_run_main_file() -> Result<()> { + let lua = create_lua()?; // Test with a valid Lua file let main_file = "examples/hosts.lua".to_string(); let result = run_main_file(&lua, &main_file); assert!(result.is_ok()); - } - #[test] - fn test_get_auth_config() { - let lua = create_lua().unwrap(); - let host = lua.create_table().unwrap(); - - // Test with user in host - host.set("address", "localhost").unwrap(); - host.set("user", "testuser").unwrap(); - host.set("private_key_file", "/path/to/key").unwrap(); - - let module_params = lua.create_table().unwrap(); - module_params.set("cmd", "echo test").unwrap(); - let module = lua - .load(chunk! { - return komandan.modules.cmd($module_params) - }) - .eval::
() - .unwrap(); - let task = lua.create_table().unwrap(); - task.set(1, module).unwrap(); - - let (user, auth) = get_auth_config(&host, &task).unwrap(); - assert_eq!(user, "testuser"); - match auth { - SSHAuthMethod::PublicKey { - private_key, - passphrase, - } => { - assert_eq!(private_key, "/path/to/key"); - assert!(passphrase.is_none()); - } - _ => panic!("Expected PublicKey authentication"), - } - - // Test with password auth - host.set("private_key_file", Value::Nil).unwrap(); - host.set("password", "testpass").unwrap(); - let (_, auth) = get_auth_config(&host, &task).unwrap(); - match auth { - SSHAuthMethod::Password(pass) => assert_eq!(pass, "testpass"), - _ => panic!("Expected Password authentication"), - } - - // Test with no authentication method - host.set("password", Value::Nil).unwrap(); - let result = get_auth_config(&host, &task); - assert!(result.is_err()); - } - - #[test] - fn test_get_elevation_config() { - let lua = create_lua().unwrap(); - let host = lua.create_table().unwrap(); - let task = lua.create_table().unwrap(); - - // Test with no elevation - let elevation = get_elevation_config(&host, &task).unwrap(); - assert!(matches!( - elevation, - Elevation { - method: ElevationMethod::None, - as_user: None - } - )); - - // Test with elevation from task - task.set("elevate", true).unwrap(); - let elevation = get_elevation_config(&host, &task).unwrap(); - assert!(matches!( - elevation, - Elevation { - method: ElevationMethod::Sudo, - as_user: None - } - )); - - // Test with custom elevation method - task.set("elevation_method", "su").unwrap(); - let elevation = get_elevation_config(&host, &task).unwrap(); - assert!(matches!( - elevation, - Elevation { - method: ElevationMethod::Su, - as_user: None - } - )); - - // Test invalid elevation method - task.set("elevation_method", "invalid").unwrap(); - assert!(get_elevation_config(&host, &task).is_err()); - } - - #[test] - fn test_setup_ssh_session() { - let lua = create_lua().unwrap(); - let host = lua.create_table().unwrap(); - host.set("address", "localhost").unwrap(); - - // Test with default settings - let ssh = setup_ssh_session(&host).unwrap(); - assert!(ssh.known_hosts_file.is_some()); - - // Test with host key check disabled - host.set("host_key_check", false).unwrap(); - let ssh = setup_ssh_session(&host).unwrap(); - assert!(ssh.known_hosts_file.is_none()); - - // Test with custom known_hosts file - host.set("known_hosts_file", "/path/to/known_hosts") - .unwrap(); - host.set("host_key_check", true).unwrap(); - let ssh = setup_ssh_session(&host).unwrap(); - assert_eq!(ssh.known_hosts_file.unwrap(), "/path/to/known_hosts"); - - // Test with known_hosts from defaults - host.set("known_hosts_file", Value::Nil).unwrap(); - lua.load(chunk! { - komandan.defaults:set_known_hosts_file("/default/known_hosts") - }) - .exec() - .unwrap(); - let ssh = setup_ssh_session(&host).unwrap(); - assert_eq!(ssh.known_hosts_file.unwrap(), "/default/known_hosts"); - } - - #[test] - fn test_setup_environment() { - let lua = create_lua().unwrap(); - let mut ssh = SSHSession::new().unwrap(); - let defaults = lua.create_table().unwrap(); - let host = lua.create_table().unwrap(); - let task = lua.create_table().unwrap(); - - // Test with environment variables at all levels - let env_defaults = lua.create_table().unwrap(); - env_defaults.set("DEFAULT_VAR", "default_value").unwrap(); - defaults.set("env", env_defaults).unwrap(); - - let env_host = lua.create_table().unwrap(); - env_host.set("HOST_VAR", "host_value").unwrap(); - env_host.set("DEFAULT_VAR", "overridden_value").unwrap(); // Override default - host.set("env", env_host).unwrap(); - - let env_task = lua.create_table().unwrap(); - env_task.set("TASK_VAR", "task_value").unwrap(); - task.set("env", env_task).unwrap(); - - setup_environment(&mut ssh, &host, &task).unwrap(); + Ok(()) } } diff --git a/src/main.rs b/src/main.rs index dfbf304..ebec3a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ mod args; +use anyhow::Result; use args::Args; use clap::Parser; use komandan::{create_lua, print_version, repl, run_main_file}; -fn main() -> anyhow::Result<()> { +fn main() -> Result<()> { let args = Args::parse(); if args.version { @@ -18,6 +19,10 @@ fn main() -> anyhow::Result<()> { lua.load(&chunk).eval::<()>()?; } + if args.dry_run { + println!("[[[ Running in dry-run mode ]]]"); + } + if let Some(main_file) = &args.main_file { run_main_file(&lua, main_file)?; } else if args.chunk.is_none() { diff --git a/src/models.rs b/src/models.rs new file mode 100644 index 0000000..7457f42 --- /dev/null +++ b/src/models.rs @@ -0,0 +1,343 @@ +use std::collections::HashMap; + +use mlua::{Error, FromLua, IntoLua, Lua, LuaSerdeExt, UserData, Value}; +use serde::{Deserialize, Serialize}; + +use crate::ssh::ElevationMethod; + +#[derive(Clone, Debug)] +pub struct Host { + name: Option, + address: String, + port: Option, + user: Option, + private_key_file: Option, + private_key_pass: Option, + password: Option, + elevate: Option, + elevation_method: Option, + as_user: Option, + env: Option>, +} + +impl FromLua for Host { + fn from_lua(lua_value: Value, _: &Lua) -> mlua::Result { + let table = lua_value.as_table().unwrap(); + Ok(Host { + name: table.get("name")?, + address: table.get("address")?, + port: table.get("port")?, + user: table.get("user")?, + private_key_file: table.get("private_key_file")?, + private_key_pass: table.get("private_key_pass")?, + password: table.get("password")?, + elevate: table.get("elevate")?, + elevation_method: match table.get::("elevation_method") { + Ok(elevation_method) => match elevation_method.as_str() { + "none" => Some(ElevationMethod::None), + "sudo" => Some(ElevationMethod::Sudo), + "su" => Some(ElevationMethod::Su), + _ => None, + }, + Err(_) => None, + }, + as_user: table.get("as_user")?, + env: table.get("env")?, + }) + } +} + +impl IntoLua for Host { + fn into_lua(self, lua: &Lua) -> mlua::Result { + let table = lua.create_table()?; + if self.name.is_some() { + table.set("name", self.name.unwrap())?; + } + table.set("address", self.address)?; + if self.port.is_some() { + table.set("port", self.port.unwrap())?; + } + if self.user.is_some() { + table.set("user", self.user.unwrap())?; + } + if self.private_key_file.is_some() { + table.set("private_key_file", self.private_key_file.unwrap())?; + } + if self.private_key_pass.is_some() { + table.set("private_key_pass", self.private_key_pass.unwrap())?; + } + if self.password.is_some() { + table.set("password", self.password.unwrap())?; + } + if self.elevate.is_some() { + table.set("elevate", self.elevate.unwrap())?; + } + if self.elevation_method.is_some() { + match self.elevation_method.unwrap() { + ElevationMethod::None => table.set("elevation_method", "none")?, + ElevationMethod::Sudo => table.set("elevation_method", "sudo")?, + ElevationMethod::Su => table.set("elevation_method", "su")?, + } + } + if self.as_user.is_some() { + table.set("as_user", self.as_user.unwrap())?; + } + if self.env.is_some() { + table.set("env", self.env.unwrap())?; + } + Ok(Value::Table(table)) + } +} + +#[derive(Clone, Debug)] +pub struct Task { + name: Option, + module: Module, + ignore_exit_code: Option, + elevate: Option, + elevation_method: Option, + as_user: Option, + env: Option>, +} + +impl FromLua for Task { + fn from_lua(lua_value: Value, lua: &Lua) -> mlua::Result { + let table = lua_value.as_table().unwrap(); + Ok(Task { + name: table.get("name")?, + module: Module::from_lua(table.get(1)?, lua)?, + ignore_exit_code: table.get("ignore_exit_code")?, + elevate: table.get("elevate")?, + elevation_method: match table.get::("elevation_method") { + Ok(elevation_method) => match elevation_method.as_str() { + "none" => Some(ElevationMethod::None), + "sudo" => Some(ElevationMethod::Sudo), + "su" => Some(ElevationMethod::Su), + _ => None, + }, + Err(_) => None, + }, + as_user: table.get("as_user")?, + env: table.get("env")?, + }) + } +} + +impl IntoLua for Task { + fn into_lua(self, lua: &Lua) -> mlua::Result { + let table = lua.create_table()?; + if self.name.is_some() { + table.set("name", self.name.unwrap())?; + } + table.set(1, self.module.into_lua(lua)?)?; + if self.ignore_exit_code.is_some() { + table.set("ignore_exit_code", self.ignore_exit_code.unwrap())?; + } + if self.elevate.is_some() { + table.set("elevate", self.elevate.unwrap())?; + } + if self.elevation_method.is_some() { + match self.elevation_method.unwrap() { + ElevationMethod::None => table.set("elevation_method", "none")?, + ElevationMethod::Sudo => table.set("elevation_method", "sudo")?, + ElevationMethod::Su => table.set("elevation_method", "su")?, + } + } + if self.as_user.is_some() { + table.set("as_user", self.as_user.unwrap())?; + } + if self.env.is_some() { + table.set("env", self.env.unwrap())?; + } + Ok(Value::Table(table)) + } +} + +#[derive(Clone, Debug)] +pub struct Module { + functions: HashMap>, + others: HashMap, +} + +impl FromLua for Module { + fn from_lua(value: Value, _: &Lua) -> mlua::Result { + let table = value.as_table().unwrap(); + let mut functions: HashMap> = HashMap::new(); + let mut others: HashMap = HashMap::new(); + for pair in table.pairs::() { + let (key, value) = pair.unwrap(); + if value.is_function() { + functions.insert(key.to_string()?, value.as_function().unwrap().dump(true)); + } else { + others.insert( + key.to_string()?, + serde_json::to_string(&value).map_err(Error::external)?, + ); + } + } + Ok(Module { functions, others }) + } +} + +impl IntoLua for Module { + fn into_lua(self, lua: &Lua) -> mlua::Result { + let table = lua.create_table()?; + self.functions.iter().for_each(|(key, value)| { + table + .set(key.as_str(), lua.load(value).into_function().unwrap()) + .unwrap(); + }); + self.others.iter().for_each(|(key, value)| { + let json: serde_json::Value = serde_json::from_str(value).unwrap(); + table + .set(key.as_str(), lua.to_value(&json).unwrap()) + .unwrap(); + }); + Ok(Value::Table(table)) + } +} + +#[derive(Serialize, Deserialize)] +pub struct KomandoResult { + stdout: String, + stderr: String, + exit_code: i32, + changed: bool, +} + +impl UserData for KomandoResult {} + +#[cfg(test)] +mod tests { + use super::*; + use mlua::Lua; + use std::collections::HashMap; + + #[test] + fn test_host_from_lua() { + let lua = Lua::new(); + let table = lua.create_table().unwrap(); + table.set("address", "127.0.0.1").unwrap(); + let host = Host::from_lua(Value::Table(table.clone()), &lua).unwrap(); + assert_eq!(host.address, "127.0.0.1"); + assert_eq!(host.name, None); + + table.set("name", "test").unwrap(); + table.set("port", 22).unwrap(); + table.set("user", "user").unwrap(); + table.set("private_key_file", "/path/to/key").unwrap(); + table.set("private_key_pass", "pass").unwrap(); + table.set("password", "password").unwrap(); + table.set("elevate", true).unwrap(); + table.set("elevation_method", "sudo").unwrap(); + table.set("as_user", "root").unwrap(); + let mut env = HashMap::new(); + env.insert("key".to_string(), "value".to_string()); + table.set("env", env.clone()).unwrap(); + + let host = Host::from_lua(Value::Table(table), &lua).unwrap(); + assert_eq!(host.address, "127.0.0.1"); + assert_eq!(host.name, Some("test".to_string())); + assert_eq!(host.port, Some(22)); + assert_eq!(host.user, Some("user".to_string())); + assert_eq!(host.private_key_file, Some("/path/to/key".to_string())); + assert_eq!(host.private_key_pass, Some("pass".to_string())); + assert_eq!(host.password, Some("password".to_string())); + assert_eq!(host.elevate, Some(true)); + assert_eq!(host.elevation_method, Some(ElevationMethod::Sudo)); + assert_eq!(host.as_user, Some("root".to_string())); + assert_eq!(host.env, Some(env)); + } + + #[test] + fn test_host_into_lua() { + let lua = Lua::new(); + let mut env = HashMap::new(); + env.insert("key".to_string(), "value".to_string()); + let host = Host { + name: Some("test".to_string()), + address: "127.0.0.1".to_string(), + port: Some(22), + user: Some("user".to_string()), + private_key_file: Some("/path/to/key".to_string()), + private_key_pass: Some("pass".to_string()), + password: Some("password".to_string()), + elevate: Some(true), + elevation_method: Some(ElevationMethod::Sudo), + as_user: Some("root".to_string()), + env: Some(env.clone()), + }; + + let table = host.into_lua(&lua).unwrap().as_table().unwrap().clone(); + assert_eq!(table.get::("address").unwrap(), "127.0.0.1"); + assert_eq!(table.get::("name").unwrap(), "test"); + assert_eq!(table.get::("port").unwrap(), 22); + assert_eq!(table.get::("user").unwrap(), "user"); + assert_eq!( + table.get::("private_key_file").unwrap(), + "/path/to/key" + ); + assert_eq!(table.get::("private_key_pass").unwrap(), "pass"); + assert_eq!(table.get::("password").unwrap(), "password"); + assert_eq!(table.get::("elevate").unwrap(), true); + assert_eq!(table.get::("elevation_method").unwrap(), "sudo"); + assert_eq!(table.get::("as_user").unwrap(), "root"); + assert_eq!(table.get::>("env").unwrap(), env); + } + + #[test] + fn test_task_from_lua() { + let lua = Lua::new(); + let table = lua.create_table().unwrap(); + let module_table = lua.create_table().unwrap(); + module_table.set("command", "echo 'hello'").unwrap(); + table.set(1, module_table).unwrap(); + let task = Task::from_lua(Value::Table(table.clone()), &lua).unwrap(); + assert_eq!(task.name, None); + assert!(task.module.functions.is_empty()); + assert_eq!( + task.module.others.get("command").unwrap().clone(), + "\"echo 'hello'\"" + ); + + table.set("name", "test").unwrap(); + table.set("ignore_exit_code", true).unwrap(); + table.set("elevate", true).unwrap(); + table.set("elevation_method", "sudo").unwrap(); + table.set("as_user", "root").unwrap(); + let mut env = HashMap::new(); + env.insert("key".to_string(), "value".to_string()); + table.set("env", env.clone()).unwrap(); + + let task = Task::from_lua(Value::Table(table), &lua).unwrap(); + assert_eq!(task.name, Some("test".to_string())); + assert_eq!(task.ignore_exit_code, Some(true)); + assert_eq!(task.elevate, Some(true)); + assert_eq!(task.elevation_method, Some(ElevationMethod::Sudo)); + assert_eq!(task.as_user, Some("root".to_string())); + assert_eq!(task.env, Some(env)); + } + + #[test] + fn test_module_from_lua() { + let lua = Lua::new(); + let table = lua.create_table().unwrap(); + table.set("command", "echo 'hello'").unwrap(); + let module = Module::from_lua(Value::Table(table.clone()), &lua).unwrap(); + assert!(module.functions.is_empty()); + assert_eq!( + module.others.get("command").unwrap().clone(), + "\"echo 'hello'\"" + ); + + let function = lua.load("return 1").into_function().unwrap(); + table.set("test_func", function).unwrap(); + + let module = Module::from_lua(Value::Table(table), &lua).unwrap(); + assert_eq!(module.functions.len(), 1); + assert_eq!( + module.others.get("command").unwrap().clone(), + "\"echo 'hello'\"" + ); + } +} diff --git a/src/modules/apt.rs b/src/modules/apt.rs index 6da91db..beeb0e9 100644 --- a/src/modules/apt.rs +++ b/src/modules/apt.rs @@ -1,49 +1,131 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn apt(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.update_cache == nil then params.update_cache = false end - if params.package == nil and params.update_cache == false then + local valid_actions = { + install = true, + remove = true, + purge = true, + upgrade = true, + autoremove = true + } + + if params.action ~= nil and not valid_actions[params.action] then + error("Invalid action: " .. params.action .. ". Valid actions are: install, remove, purge, upgrade, autoremove.") + end + + if (params.action == "install" or params.action == "remove" or params.action == "purge") and params.package == nil then error("package is required") end + if params.package ~= nil and params.action == nil then + params.action = "install" + end + if params.install_recommends == nil then params.install_recommends = true end - if params.action == nil then - params.action = "install" + params.install_opts = "" + if not params.install_recommends then + params.install_opts = params.install_opts .. " --no-install-recommends" + end + + local function sanitize(input) + if type(input) ~= "string" then + return nil -- Ensure input is a string + end + return input:gsub("[^%w%-_]", "") end + params.package = sanitize(params.package) + params.install_opts = sanitize(params.install_opts) + local module = $base_module:new({ name = "apt" }) module.params = $params + module.update_cache = function(self) + local update_result = self.ssh:cmd("apt update") + if update_result.exit_code == 0 and not update_result.stdout:match("Get:") then + self.ssh:set_changed(false) + end + end + + module.is_installed = function(self) + if self.params.package == nil then + return false + end + + local pkg_check = self.ssh:cmdq("dpkg-query -W -f='${Status}' " .. self.params.package .. " 2>/dev/null | grep -q 'ok installed'") + if pkg_check.exit_code == 0 then + return true + else + return false + end + end + + module.dry_run = function(self) + if self.params.update_cache then + self:update_cache() + end + + local installed = self:is_installed() + + if self.params.action == "install" then + self.ssh:cmd("apt -s install " .. self.params.package .. self.params.install_opts) + if installed then + self.ssh:set_changed(false) + end + elseif self.params.action == "remove" then + self.ssh:cmd("apt -s remove " .. self.params.package) + if not installed then + self.ssh:set_changed(false) + end + elseif self.params.action == "purge" then + self.ssh:cmd("apt -s purge " .. self.params.package) + if not installed then + self.ssh:set_changed(false) + end + elseif self.params.action == "upgrade" then + self.ssh:cmd("apt -s upgrade") + elseif self.params.action == "autoremove" then + self.ssh:cmd("apt -s autoremove") + end + end + module.run = function(self) if self.params.update_cache then - self.ssh:cmd("apt update") + self:update_cache() end - if self.params.package == nill then + if self.params.package == nil then return end - local install_opts = "" - if not self.params.install_recommends then - install_opts = install_opts .. " --no-install-recommends" - end + local installed = self:is_installed() if self.params.action == "install" then - self.ssh:cmd("apt install -y " .. self.params.package .. install_opts) + self.ssh:cmd("apt install -y " .. self.params.package .. self.params.install_opts) + if installed then + self.ssh:set_changed(false) + end elseif self.params.action == "remove" then self.ssh:cmd("apt remove -y " .. self.params.package) + if not installed then + self.ssh:set_changed(false) + end elseif self.params.action == "purge" then self.ssh:cmd("apt purge -y " .. self.params.package) + if not installed then + self.ssh:set_changed(false) + end elseif self.params.action == "upgrade" then self.ssh:cmd("apt upgrade -y") elseif self.params.action == "autoremove" then @@ -68,9 +150,10 @@ mod tests { use super::*; #[test] - fn test_package_required() { + fn test_apt_package_required() { let lua = create_lua().unwrap(); let params = lua.create_table().unwrap(); + params.set("action", "install").unwrap(); let result = apt(&lua, params); assert!(result.is_err()); assert!(result @@ -80,7 +163,7 @@ mod tests { } #[test] - fn test_valid_package() { + fn test_apt_valid_package() { let lua = create_lua().unwrap(); let params = lua.create_table().unwrap(); params.set("package", "vim").unwrap(); diff --git a/src/modules/base.rs b/src/modules/base.rs new file mode 100644 index 0000000..f86d982 --- /dev/null +++ b/src/modules/base.rs @@ -0,0 +1,22 @@ +use mlua::{chunk, Table}; + +pub fn base_module(lua: &mlua::Lua) -> mlua::Result
{ + lua.load(chunk! { + local KomandanModule = {} + + KomandanModule.new = function(self,data) + local o = setmetatable({}, { __index = self }) + o.name = data.name + return o + end + + KomandanModule.run = function(self) + end + + KomandanModule.cleanup = function(self) + end + + return KomandanModule + }) + .eval::
() +} diff --git a/src/modules/cmd.rs b/src/modules/cmd.rs index 68cfc4a..86386dd 100644 --- a/src/modules/cmd.rs +++ b/src/modules/cmd.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn cmd(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "cmd" }) diff --git a/src/modules/core.rs b/src/modules/core.rs new file mode 100644 index 0000000..96abbdb --- /dev/null +++ b/src/modules/core.rs @@ -0,0 +1,23 @@ +use mlua::Table; + +use super::*; + +pub fn collect_core_modules(lua: &mlua::Lua) -> mlua::Result
{ + let modules = lua.create_table()?; + modules.set("apt", lua.create_function(apt::apt)?)?; + modules.set("cmd", lua.create_function(cmd::cmd)?)?; + modules.set("download", lua.create_function(download::download)?)?; + modules.set("lineinfile", lua.create_function(lineinfile::lineinfile)?)?; + modules.set( + "postgresql_user", + lua.create_function(postgresql_user::postgresql_user)?, + )?; + modules.set("script", lua.create_function(script::script)?)?; + modules.set( + "systemd_service", + lua.create_function(systemd_service::systemd_service)?, + )?; + modules.set("template", lua.create_function(template::template)?)?; + modules.set("upload", lua.create_function(upload::upload)?)?; + Ok(modules) +} diff --git a/src/modules/download.rs b/src/modules/download.rs index 0873301..8bf954d 100644 --- a/src/modules/download.rs +++ b/src/modules/download.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn download(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "download" }) diff --git a/src/modules/lineinfile.rs b/src/modules/lineinfile.rs index b14bdf5..5db5e30 100644 --- a/src/modules/lineinfile.rs +++ b/src/modules/lineinfile.rs @@ -8,7 +8,7 @@ pub fn lineinfile(lua: &Lua, params: Table) -> mlua::Result
{ .take(10) .collect(); - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.path == nil then @@ -41,9 +41,9 @@ pub fn lineinfile(lua: &Lua, params: Table) -> mlua::Result
{ module.random_file_name = $random_file_name module.lineinfile_script = $LINEINFILE_SCRIPT - module.run = function(self) + module.run_lineinfile_script = function(self) local tmpdir = self.ssh:get_tmpdir() - self.remote_script = tmpdir .. "/." .. self.random_file_name + self.remote_script = tmpdir .. "/." .. self.random_file_name self.ssh:write_remote_file(self.remote_script, self.lineinfile_script) self.ssh:chmod(self.remote_script, "+x") @@ -64,7 +64,26 @@ pub fn lineinfile(lua: &Lua, params: Table) -> mlua::Result
{ cmd = cmd .. " --insert_before \"" .. self.params.insert_before .. "\"" end - self.ssh:cmd(cmd) + if self.params.dry_run then + cmd = cmd .. " --dry-run" + end + + return self.ssh:cmd(cmd) + end + + module.dry_run = function(self) + self.params.dry_run = true + local result = self:run_lineinfile_script() + if result.stdout == "OK" then + self.ssh:set_changed(false) + end + end + + module.run = function(self) + local result = self:run_lineinfile_script() + if result.stdout == "OK" then + self.ssh:set_changed(false) + end end module.cleanup = function(self) @@ -86,6 +105,7 @@ const LINEINFILE_SCRIPT: &str = r#"#!/bin/sh STATE="present" CREATE="false" BACKUP="false" +DRYRUN="false" # Parse command-line arguments while [ $# -gt 0 ]; do @@ -122,6 +142,10 @@ while [ $# -gt 0 ]; do BACKUP="$2" shift 2 ;; + --dry-run) + DRYRUN="true" + shift 1 + ;; *) echo "Unknown option: $1" exit 1 @@ -138,8 +162,12 @@ fi # Create the file if it doesn't exist and --create is true if [ ! -f "$FILE_PATH" ]; then if [ "$CREATE" = "true" ]; then - touch "$FILE_PATH" - echo "File created: $FILE_PATH" + if [ "$DRYRUN" = "true" ]; then + echo "[DRY-RUN] File would be created: $FILE_PATH" + else + touch "$FILE_PATH" + echo "Changed" + fi else echo "Error: File '$FILE_PATH' does not exist and '--create' is set to false" exit 1 @@ -149,8 +177,12 @@ fi # Create a backup if requested if [ "$BACKUP" = "true" ]; then BACKUP_FILE="$FILE_PATH.$(date +%Y%m%d%H%M%S).bak" - cp "$FILE_PATH" "$BACKUP_FILE" - echo "Backup created: $BACKUP_FILE" + if [ "$DRYRUN" = "true" ]; then + echo "[DRY-RUN] Backup would be created: $BACKUP_FILE" + else + cp "$FILE_PATH" "$BACKUP_FILE" + echo "Changed" + fi fi # Handle the 'present' state @@ -162,64 +194,60 @@ if [ "$STATE" = "present" ]; then # Check if the line already exists if grep -Fxq "$LINE" "$FILE_PATH"; then - echo "Line already exists, no changes made." + echo "OK" # Unchanged exit 0 fi # Handle pattern replacement if [ -n "$REGEXP" ]; then if grep -q "$REGEXP" "$FILE_PATH"; then - sed -i.bak "/$REGEXP/c\$LINE" "$FILE_PATH" - echo "Line replaced matching pattern: $REGEXP" + if [ "$DRYRUN" = "true" ]; then + echo "[DRY-RUN] Line matching '$REGEXP' would be replaced with: $LINE" + else + sed -i "/$REGEXP/c\\$LINE" "$FILE_PATH" + echo "Changed" + fi exit 0 fi fi # Handle line insertion if [ -n "$INSERTAFTER" ]; then - if [ "$INSERTAFTER" = "EOF" ]; then - echo "$LINE" >> "$FILE_PATH" - echo "Line appended to the end of the file." + if [ "$DRYRUN" = "true" ]; then + echo "[DRY-RUN] Line '$LINE' would be inserted after pattern: $INSERTAFTER" else - sed -i.bak "/$INSERTAFTER/a\$LINE" "$FILE_PATH" - echo "Line inserted after pattern: $INSERTAFTER" + if [ "$INSERTAFTER" = "EOF" ]; then + echo "$LINE" >> "$FILE_PATH" + echo "Changed" + else + sed -i "/$INSERTAFTER/a\\$LINE" "$FILE_PATH" + echo "Changed" + fi fi elif [ -n "$INSERTBEFORE" ]; then - if [ "$INSERTBEFORE" = "BOF" ]; then - sed -i.bak "1i\$LINE" "$FILE_PATH" - echo "Line inserted at the beginning of the file." + if [ "$DRYRUN" = "true" ]; then + echo "[DRY-RUN] Line '$LINE' would be inserted before pattern: $INSERTBEFORE" else - sed -i.bak "/$INSERTBEFORE/i\$LINE" "$FILE_PATH" - echo "Line inserted before pattern: $INSERTBEFORE" + if [ "$INSERTBEFORE" = "BOF" ]; then + sed -i "1i\\$LINE" "$FILE_PATH" + echo "Changed" + else + sed -i "/$INSERTBEFORE/i\\$LINE" "$FILE_PATH" + echo "Changed" + fi fi else - echo "$LINE" >> "$FILE_PATH" - echo "Line appended to the file." - fi - exit 0 -fi - -# Handle the 'absent' state -if [ "$STATE" = "absent" ]; then - if [ -z "$REGEXP" ] && [ -z "$LINE" ]; then - echo "Error: '--pattern' or '--line' is required for 'absent' state" - exit 1 - fi - - # Remove lines matching the exact line - if [ -n "$LINE" ]; then - sed -i.bak "/^$(echo "$LINE" | sed 's/[^^]/[&]/g; s/\^/\\^/g')$/d" "$FILE_PATH" - echo "Removed line: $LINE" - fi - - # Remove lines matching the regex - if [ -n "$REGEXP" ]; then - sed -i.bak "/$REGEXP/d" "$FILE_PATH" - echo "Removed lines matching pattern: $REGEXP" + if [ "$DRYRUN" = "true" ]; then + echo "[DRY-RUN] Line '$LINE' would be appended to the file." + else + echo "$LINE" >> "$FILE_PATH" + echo "Changed" + fi fi exit 0 fi +# Handle 'absent' state if implemented in the future # If no valid state is provided echo "Error: Invalid state '$STATE'. Use 'present' or 'absent'." exit 1 diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 6e766d7..4d3b78a 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -1,69 +1,14 @@ mod apt; +mod base; mod cmd; +mod core; mod download; mod lineinfile; +mod postgresql_user; mod script; mod systemd_service; mod template; mod upload; -use mlua::{chunk, Table}; - -pub fn base_module(lua: &mlua::Lua) -> Table { - return lua - .load(chunk! { - local KomandanModule = {} - - KomandanModule.new = function(self,data) - local o = setmetatable({}, { __index = self }) - o.name = data.name - return o - end - - KomandanModule.run = function(self) - end - - KomandanModule.cleanup = function(self) - end - - return KomandanModule - }) - .eval::
() - .unwrap(); -} - -pub fn collect_modules(lua: &mlua::Lua) -> Table { - let modules = lua.create_table().unwrap(); - modules - .set("apt", lua.create_function(apt::apt).unwrap()) - .unwrap(); - modules - .set("cmd", lua.create_function(cmd::cmd).unwrap()) - .unwrap(); - modules - .set("download", lua.create_function(download::download).unwrap()) - .unwrap(); - modules - .set( - "lineinfile", - lua.create_function(lineinfile::lineinfile).unwrap(), - ) - .unwrap(); - modules - .set("script", lua.create_function(script::script).unwrap()) - .unwrap(); - modules - .set( - "systemd_service", - lua.create_function(systemd_service::systemd_service) - .unwrap(), - ) - .unwrap(); - modules - .set("template", lua.create_function(template::template).unwrap()) - .unwrap(); - modules - .set("upload", lua.create_function(upload::upload).unwrap()) - .unwrap(); - return modules; -} +pub use base::*; +pub use core::*; diff --git a/src/modules/postgresql_user.rs b/src/modules/postgresql_user.rs new file mode 100644 index 0000000..e3e3ca9 --- /dev/null +++ b/src/modules/postgresql_user.rs @@ -0,0 +1,143 @@ +use mlua::{chunk, ExternalResult, Lua, Table}; + +pub fn postgresql_user(lua: &Lua, params: Table) -> mlua::Result
{ + let base_module = super::base_module(lua)?; + let module = lua + .load(chunk! { + if params.name == nil then + error("'name' parameter is required") + end + + local valid_actions = { + create = true, + drop = true, + } + + if params.action ~= nil and not valid_actions[params.action] then + error("Invalid action: " .. params.action .. ". Valid actions are: create and drop.") + end + + params.action = params.action or "create" + + local module = $base_module:new({ name = "postgresql_user" }) + + module.params = $params + + module.is_exists = function(self) + self.ssh:requires("psql") + local result = self.ssh:cmdq("psql -tAc \"SELECT EXISTS(SELECT 1 FROM pg_roles WHERE rolname = '" .. self.params.name .. "')::int;\"") + if result.exit_code ~= 0 then + error(result.stderr) + end + if result.stdout == "1" then + return true + end + return false + end + + module.dry_run = function(self) + if self.params.action == "create" then + if self:is_exists() then + self.ssh:set_changed(false) + end + elseif self.params.action == "drop" then + if not self:is_exists() then + self.ssh:set_changed(false) + end + end + end + + module.run = function(self) + local query = "" + if self.params.action == "create" then + query = "CREATE USER " .. self.params.name + if self.params.role_attr_flags ~= nil or self.params.password ~= nil then + query = query .. " WITH " + if self.params.role_attr_flags ~= nil then + query = query .. " " .. self.params.role_attr_flags + end + if self.params.password ~= nil then + query = query .. " PASSWORD '" .. self.params.password .. "'" + end + end + elseif self.params.action == "drop" then + query = "DROP ROLE " .. self.params.name + end + query = query .. ";" + + if self.params.action == "create" then + if not self:is_exists() then + self.ssh:cmdq("psql -c \"" .. query .. "\"") + else + self.ssh:set_changed(false) + end + elseif self.params.action == "drop" then + if self:is_exists() then + self.ssh:cmdq("psql -c \"" .. query .. "\"") + else + self.ssh:set_changed(false) + end + end + end + + return module + }) + .set_name("postgresql_user") + .eval::
() + .into_lua_err()?; + + Ok(module) +} + +// Tests +#[cfg(test)] +mod tests { + use super::*; + use mlua::Lua; + + fn setup_lua() -> Lua { + Lua::new() + } + + #[test] + fn test_postgresql_user_requires_name_parameter() { + let lua = setup_lua(); + let params = lua.create_table().unwrap(); + + let result = postgresql_user(&lua, params); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("'name' parameter is required")); + } + + #[test] + fn test_postgresql_user_validates_action_parameter() { + let lua = setup_lua(); + let params = lua.create_table().unwrap(); + params.set("name", "test_user").unwrap(); + params.set("action", "invalid_action").unwrap(); + + let result = postgresql_user(&lua, params); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Invalid action")); + } + + #[test] + fn test_postgresql_user_defaults_to_create_action() { + let lua = setup_lua(); + let params = lua.create_table().unwrap(); + params.set("name", "test_user").unwrap(); + + let result = postgresql_user(&lua, params); + assert!(result.is_ok()); + let module = result.unwrap(); + let action: String = module + .get::
("params") + .unwrap() + .get("action") + .unwrap(); + assert_eq!(action, "create"); + } +} diff --git a/src/modules/script.rs b/src/modules/script.rs index 360ea3e..e4ee06b 100644 --- a/src/modules/script.rs +++ b/src/modules/script.rs @@ -8,7 +8,7 @@ pub fn script(lua: &Lua, params: Table) -> mlua::Result
{ .take(10) .collect(); - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.script == nil and params.from_file == nil then diff --git a/src/modules/systemd_service.rs b/src/modules/systemd_service.rs index fc924b9..3e62ea0 100644 --- a/src/modules/systemd_service.rs +++ b/src/modules/systemd_service.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn systemd_service(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.name == nil then diff --git a/src/modules/template.rs b/src/modules/template.rs index fe0c2c6..7476132 100644 --- a/src/modules/template.rs +++ b/src/modules/template.rs @@ -40,7 +40,7 @@ pub fn template(lua: &Lua, params: Table) -> mlua::Result
{ .take(10) .collect(); - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "template" }) diff --git a/src/modules/upload.rs b/src/modules/upload.rs index 1e0c3a3..7ab46be 100644 --- a/src/modules/upload.rs +++ b/src/modules/upload.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn upload(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "upload" }) diff --git a/src/report.rs b/src/report.rs new file mode 100644 index 0000000..90fe0a7 --- /dev/null +++ b/src/report.rs @@ -0,0 +1,116 @@ +use std::{ + collections::HashMap, + sync::{Mutex, OnceLock}, +}; + +use clap::Parser; + +use crate::args::Args; + +static REPORT: OnceLock>> = OnceLock::new(); + +fn get_report() -> &'static Mutex> { + REPORT.get_or_init(|| Mutex::new(Vec::new())) +} + +pub fn insert_record(task: String, host: String, status: TaskStatus) { + let record = ReportRecord { task, host, status }; + let report = get_report(); + report.lock().unwrap().push(record); +} + +pub fn generate_report() { + let report = get_report().lock().unwrap(); + if report.is_empty() { + return; + } + let width = 80; + let col2_width = 8; + let col1_width = width - col2_width - 2; + println!(); + println!("{:=^width$}", " Komando Report "); + if Args::parse().dry_run { + println!("{:-^width$}", " Dry-run mode: no changes were made "); + } + println!("{:col2_width$}", "Task on Host", "Status"); + println!("{:-() + ); + } + let col1_width = col1_width - 3; + println!(" - {:) -> std::fmt::Result { + match self { + TaskStatus::OK => write!(f, "OK"), + TaskStatus::Changed => write!(f, "Changed"), + TaskStatus::Failed => write!(f, "Failed"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_insert_record() { + insert_record("task1".to_string(), "host1".to_string(), TaskStatus::OK); + insert_record( + "task1".to_string(), + "host2".to_string(), + TaskStatus::Changed, + ); + insert_record("task2".to_string(), "host1".to_string(), TaskStatus::Failed); + + let report = get_report().lock().unwrap(); + assert_eq!(report.len(), 3); + assert_eq!(report[0].task, "task1"); + assert_eq!(report[0].host, "host1"); + assert_eq!(report[0].status, TaskStatus::OK); + assert_eq!(report[1].task, "task1"); + assert_eq!(report[1].host, "host2"); + assert_eq!(report[1].status, TaskStatus::Changed); + assert_eq!(report[2].task, "task2"); + assert_eq!(report[2].host, "host1"); + assert_eq!(report[2].status, TaskStatus::Failed); + } +} diff --git a/src/ssh.rs b/src/ssh.rs index 7631203..8b2f16a 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -6,8 +6,8 @@ use std::{ path::Path, }; -use anyhow::Result; -use mlua::UserData; +use anyhow::{Error, Result}; +use mlua::{Error::RuntimeError, UserData, Value}; use ssh2::{CheckResult, KnownHostFileKind, Session, Sftp}; #[derive(Debug, PartialEq)] @@ -41,6 +41,7 @@ pub struct SSHSession { stdout: Option, stderr: Option, exit_code: Option, + changed: Option, } impl SSHSession { @@ -56,6 +57,7 @@ impl SSHSession { stdout: Some(String::new()), stderr: Some(String::new()), exit_code: Some(0), + changed: Some(true), }) } @@ -71,42 +73,39 @@ impl SSHSession { self.session.set_tcp_stream(tcp); self.session.handshake()?; - match &self.known_hosts_file { - Some(file) => { - let host_key = self.session.host_key().unwrap(); - let mut known_hosts = self.session.known_hosts()?; - match known_hosts.read_file(Path::new(file.as_str()), KnownHostFileKind::OpenSSH) { - Ok(_) => {} - Err(_) => { - return Err(anyhow::Error::msg( - format!("SSH host key verification failed. Please add the host key to the known_hosts file: {}", file), - )); - } - }; - - let known_hosts_check_result = known_hosts.check(&address, &host_key.0); - match known_hosts_check_result { - CheckResult::Match => {} - _ => { - return Err(anyhow::Error::msg( - format!("SSH host key verification failed ({:?}). Please check the known_hosts file: {}", known_hosts_check_result, file), - )); - } - }; - } - None => {} + if let Some(file) = &self.known_hosts_file { + let host_key = self.session.host_key().unwrap(); + let mut known_hosts = self.session.known_hosts()?; + match known_hosts.read_file(Path::new(file.as_str()), KnownHostFileKind::OpenSSH) { + Ok(_) => {} + Err(_) => { + return Err(Error::msg( + format!("SSH host key verification failed. Please add the host key to the known_hosts file: {}", file) + )); + } + }; + + let known_hosts_check_result = known_hosts.check(address, host_key.0); + match known_hosts_check_result { + CheckResult::Match => {} + _ => { + return Err(Error::msg( + format!("SSH host key verification failed ({:?}). Please check the known_hosts file: {}", known_hosts_check_result, file) + )); + } + }; } match auth_method { SSHAuthMethod::Password(password) => { - self.session.userauth_password(&username, &password)?; + self.session.userauth_password(username, &password)?; } SSHAuthMethod::PublicKey { private_key, passphrase, } => { self.session.userauth_pubkey_file( - &username, + username, None, Path::new(&private_key), passphrase.as_deref(), @@ -115,7 +114,7 @@ impl SSHSession { } if !self.session.authenticated() { - return Err(anyhow::Error::msg("SSH authentication failed.")); + return Err(Error::msg("SSH authentication failed.")); } Ok(()) @@ -153,6 +152,20 @@ impl SSHSession { Ok((stdout, stderr, exit_code)) } + pub fn cmdq(&mut self, command: &str) -> Result<(String, String, i32)> { + let mut channel = self.execute_command(command)?; + let mut stdout = String::new(); + let mut stderr = String::new(); + + channel.read_to_string(&mut stdout)?; + channel.stderr().read_to_string(&mut stderr)?; + stdout = stdout.trim_end_matches('\n').to_string(); + channel.wait_close()?; + let exit_code = channel.exit_status()?; + + Ok((stdout, stderr, exit_code)) + } + pub fn prepare_command(&mut self, command: &str) -> Result { let command = match self.elevation.method { ElevationMethod::Su => match &self.elevation.as_user { @@ -160,8 +173,8 @@ impl SSHSession { None => format!("su -c '{}'", command), }, ElevationMethod::Sudo => match &self.elevation.as_user { - Some(user) => format!("sudo -u {} {}", user, command), - None => format!("sudo {}", command), + Some(user) => format!("sudo -E -u {} {}", user, command), + None => format!("sudo -E {}", command), }, _ => command.to_string(), }; @@ -224,7 +237,7 @@ impl SSHSession { let mut remote_file = self.session .scp_send(Path::new(remote_path), 0o644, content_length, None)?; - remote_file.write(content)?; + remote_file.write_all(content)?; remote_file.send_eof()?; remote_file.wait_eof()?; remote_file.close()?; @@ -244,7 +257,7 @@ fn upload_file(sftp: &mut Sftp, local_path: &Path, remote_path: &Path) -> io::Re } fn upload_directory(sftp: &mut Sftp, local_path: &Path, remote_path: &Path) -> io::Result<()> { - if !sftp.stat(remote_path).is_ok() { + if sftp.stat(remote_path).is_err() { sftp.mkdir(remote_path, 0o755)?; } @@ -312,6 +325,57 @@ impl UserData for SSHSession { Ok(table) }); + methods.add_method_mut("cmdq", |lua, this, command: String| { + let command = this.prepare_command(command.as_str())?; + let cmd_result = this.cmdq(&command); + let (stdout, stderr, exit_code) = cmd_result?; + + let table = lua.create_table()?; + table.set("stdout", stdout)?; + table.set("stderr", stderr)?; + table.set("exit_code", exit_code)?; + + Ok(table) + }); + + methods.add_method_mut("requires", |_, this, commands: Value| { + if !commands.is_table() && !commands.is_string() { + return Err(RuntimeError( + "'requires' must be called with a string or table".to_string(), + )) + } + + let commands = if commands.is_string() { + commands.to_string()? + } else { + let commands_table = commands.as_table().unwrap(); + let mut strings = String::new(); + for i in 1..= commands_table.len()? { + let s = commands_table.get::(i)?; + strings.push_str(&s); + if i < commands_table.len()? { + strings.push(' '); + } + } + strings + }; + + let command = this.prepare_command(format!("cmds=\"{}\"; unavailable=\"\"; for cmd in $(echo \"$cmds\"); do command -v \"$cmd\" >/dev/null 2>&1 || unavailable=\"$unavailable, $cmd\"; done; [ -z \"$unavailable\" ] || {{ echo \"${{unavailable#, }}\"; false; }}", commands).as_str())?; + let cmd_result = this.cmdq(&command); + let (stdout, _, exit_code) = cmd_result?; + + if exit_code != 0 { + return Err(RuntimeError( + format!( + "required commands not found on the remote host: {}", + stdout + ), + )) + } + + Ok(()) + }); + methods.add_method_mut( "write_remote_file", |_, this, (remote_path, content): (String, String)| { @@ -357,11 +421,17 @@ impl UserData for SSHSession { Ok(()) }); - methods.add_method("get_session_results", |lua, this, ()| { + methods.add_method_mut("set_changed", |_, this, changed: bool| { + this.changed = Some(changed); + Ok(()) + }); + + methods.add_method("get_session_result", |lua, this, ()| { let table = lua.create_table()?; table.set("stdout", this.stdout.as_ref().unwrap().clone())?; table.set("stderr", this.stderr.as_ref().unwrap().clone())?; table.set("exit_code", this.exit_code.unwrap())?; + table.set("changed", this.changed.unwrap())?; Ok(table) }); } @@ -399,13 +469,13 @@ mod tests { session.elevation.method = ElevationMethod::Sudo; session.elevation.as_user = None; let cmd = session.prepare_command("ls -la").unwrap(); - assert_eq!(cmd, "sudo ls -la"); + assert_eq!(cmd, "sudo -E ls -la"); // Test with sudo elevation and user session.elevation.method = ElevationMethod::Sudo; session.elevation.as_user = Some("admin".to_string()); let cmd = session.prepare_command("ls -la").unwrap(); - assert_eq!(cmd, "sudo -u admin ls -la"); + assert_eq!(cmd, "sudo -E -u admin ls -la"); // Test with su elevation session.elevation.method = ElevationMethod::Su; diff --git a/src/util.rs b/src/util.rs index 6296d84..d06ef8a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -116,10 +116,10 @@ pub fn parse_hosts_json_file(lua: &Lua, path: Value) -> mlua::Result
{ Ok(_) => match parse_hosts_json(lua, content) { Ok(h) => h, Err(_) => { - return Err(RuntimeError(String::from(format!( + return Err(RuntimeError(format!( "Failed to parse JSON file from '{}'", path - )))); + ))); } }, Err(_) => return Err(RuntimeError(String::from("Failed to read JSON file"))), @@ -158,20 +158,14 @@ pub fn parse_hosts_json_url(lua: &Lua, url: Value) -> mlua::Result
{ String::from_utf8_lossy(&response.body).to_string() } Err(e) => { - return Err(RuntimeError(String::from(format!( - "Failed to fetch URL: {:?}", - e - )))); + return Err(RuntimeError(format!("Failed to fetch URL: {:?}", e))); } }; let hosts = match parse_hosts_json(lua, content) { Ok(h) => h, Err(_) => { - return Err(RuntimeError(String::from(format!( - "Failed to parse JSON from '{}'", - url - )))); + return Err(RuntimeError(format!("Failed to parse JSON from '{}'", url))); } }; @@ -209,12 +203,9 @@ fn parse_hosts_json(lua: &Lua, content: String) -> mlua::Result
{ for pair in lua_table.pairs() { let (_, value): (Value, Value) = pair?; - match validate_host(&lua, value) { - Ok(host) => { - hosts.set(hosts.len()? + 1, host)?; - } - Err(_) => {} - }; + if let Ok(host) = validate_host(lua, value) { + hosts.set(hosts.len()? + 1, host)?; + } } Ok(hosts) @@ -236,7 +227,7 @@ pub fn host_display(host: &Table) -> String { match host.get::("name") { Ok(name) => format!("{} ({})", name, address), - Err(_) => format!("{}", address), + Err(_) => address, } } @@ -264,6 +255,7 @@ mod tests { let args = Args { main_file: None, chunk: None, + dry_run: false, interactive: false, verbose: true, version: false, @@ -281,6 +273,7 @@ mod tests { let args = Args { main_file: None, chunk: None, + dry_run: false, interactive: false, verbose: false, version: false, diff --git a/src/validator.rs b/src/validator.rs index a8e96f1..823273d 100644 --- a/src/validator.rs +++ b/src/validator.rs @@ -2,15 +2,15 @@ use mlua::{chunk, Error::RuntimeError, ExternalResult, Integer, Lua, Table, Valu pub fn validate_host(lua: &Lua, host: Value) -> mlua::Result
{ if !host.is_table() { - return Err(RuntimeError(format!("Host is not a table."))); + return Err(RuntimeError("Host is not a table.".to_string())); } let address = host.as_table().unwrap().get::("address")?; if address.is_nil() { - return Err(RuntimeError(format!("Host address is empty."))); + return Err(RuntimeError("Host address is empty.".to_string())); } if !address.is_string() { - return Err(RuntimeError(format!("Host address is invalid."))); + return Err(RuntimeError("Host address is invalid.".to_string())); } let port = host.as_table().unwrap().get::("port")?; @@ -245,7 +245,7 @@ mod tests { let result = super::validate_module(&lua, mlua::Value::String(lua.create_string("ls").unwrap())); - eprint!("result: {:#?}\n", result.clone().err()); + eprintln!("result: {:#?}", result.clone().err()); assert!(result.is_ok()); } diff --git a/tests/komando.rs b/tests/komando.rs index bfddfe5..b885698 100644 --- a/tests/komando.rs +++ b/tests/komando.rs @@ -249,3 +249,30 @@ fn test_komando_script_from_file() { assert!(result_table.get::("stdout").unwrap() == "hello"); assert!(result_table.get::("stderr").unwrap() == ""); } + +#[test] +fn test_komando_apt() { + let lua = create_lua().unwrap(); + + let result_table = lua + .load(chunk! { + local hosts = { + address = "localhost", + user = "usertest", + private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519" + } + + local task = { + komandan.modules.apt({ + package = "tar", + }), + elevate = true + } + + return komandan.komando(hosts, task) + }) + .eval::
() + .unwrap(); + + assert!(result_table.get::("exit_code").unwrap() == 0); +}