From c8044b519fed328e06a2b74f3c0429696acc50bb Mon Sep 17 00:00:00 2001 From: Abdul Munif Hanafi Date: Tue, 31 Dec 2024 16:36:11 +0700 Subject: [PATCH] Refactor and tidy up code --- Cargo.lock | 87 +++++-------- Cargo.toml | 11 +- src/defaults.rs | 219 +++++++++++++------------------ src/lib.rs | 232 ++++++++++++++++----------------- src/main.rs | 3 +- src/modules/apt.rs | 2 +- src/modules/base.rs | 22 ++++ src/modules/cmd.rs | 2 +- src/modules/core.rs | 19 +++ src/modules/download.rs | 2 +- src/modules/lineinfile.rs | 2 +- src/modules/mod.rs | 64 +-------- src/modules/script.rs | 2 +- src/modules/systemd_service.rs | 2 +- src/modules/template.rs | 2 +- src/modules/upload.rs | 2 +- src/ssh.rs | 57 ++++---- src/util.rs | 25 ++-- src/validator.rs | 8 +- 19 files changed, 340 insertions(+), 423 deletions(-) create mode 100644 src/modules/base.rs create mode 100644 src/modules/core.rs diff --git a/Cargo.lock b/Cargo.lock index b06dcde..39a43cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,9 +62,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "autocfg" @@ -129,9 +129,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "clap" -version = "4.5.21" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" dependencies = [ "clap_builder", "clap_derive", @@ -139,9 +139,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.21" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" dependencies = [ "anstream", "anstyle", @@ -158,14 +158,14 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] [[package]] name = "clap_lex" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "clipboard-win" @@ -347,7 +347,6 @@ name = "komandan" version = "0.1.0" dependencies = [ "anyhow", - "base64", "clap", "http-klien", "minijinja", @@ -452,9 +451,9 @@ dependencies = [ [[package]] name = "mlua" -version = "0.10.1" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ae9546e4a268c309804e8bbb7526e31cbfdedca7cd60ac1b987d0b212e0d876" +checksum = "9ea43c3ffac2d0798bd7128815212dd78c98316b299b7a902dabef13dc7b6b8d" dependencies = [ "anyhow", "bstr", @@ -471,9 +470,9 @@ dependencies = [ [[package]] name = "mlua-sys" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efa6bf1a64f06848749b7e7727417f4ec2121599e2a10ef0a8a3888b0e9a5a0d" +checksum = "63a11d485edf0f3f04a508615d36c7d50d299cf61a7ee6d3e2530651e0a31771" dependencies = [ "cc", "cfg-if", @@ -484,17 +483,17 @@ dependencies = [ [[package]] name = "mlua_derive" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cfc5faa2e0d044b3f5f0879be2920e0a711c97744c42cf1c295cb183668933e" +checksum = "870d71c172fcf491c6b5fb4c04160619a2ee3e5a42a1402269c66bcbf1dd4deb" dependencies = [ "itertools", "once_cell", - "proc-macro-error", + "proc-macro-error2", "proc-macro2", "quote", "regex", - "syn 2.0.89", + "syn", ] [[package]] @@ -556,7 +555,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] [[package]] @@ -654,27 +653,25 @@ dependencies = [ ] [[package]] -name = "proc-macro-error" -version = "1.0.4" +name = "proc-macro-error-attr2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" dependencies = [ - "proc-macro-error-attr", "proc-macro2", "quote", - "syn 1.0.109", - "version_check", ] [[package]] -name = "proc-macro-error-attr" -version = "1.0.4" +name = "proc-macro-error2" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" dependencies = [ + "proc-macro-error-attr2", "proc-macro2", "quote", - "version_check", + "syn", ] [[package]] @@ -857,9 +854,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] @@ -876,20 +873,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.134" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" dependencies = [ "itoa", "memchr", @@ -927,16 +924,6 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.89" @@ -997,12 +984,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1149,5 +1130,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn", ] diff --git a/Cargo.toml b/Cargo.toml index f7e5348..b079b2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,11 @@ edition = "2021" vendored-openssl = ["http-klien/vendored-openssl", "ssh2/vendored-openssl"] [dependencies] -anyhow = "1.0.93" -base64 = "0.22.1" -clap = { version = "4.5.21", features = ["derive"] } +anyhow = "1.0.95" +clap = { version = "4.5.23", features = ["derive"] } http-klien = { git = "https://github.com/hahnavi/http-klien-rs", branch = "main" } minijinja = "2.5.0" -mlua = { version = "0.10.1", features = [ +mlua = { version = "0.10.2", features = [ "anyhow", "luajit", "macros", @@ -25,8 +24,8 @@ rand = "0.8.5" rayon = "1.10.0" regex = "1.11.1" rustyline = "15.0.0" -serde = { version = "1.0.216", features = ["derive"] } -serde_json = "1.0.133" +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0.134" ssh2 = { version = "0.9.4" } [dev-dependencies] diff --git a/src/defaults.rs b/src/defaults.rs index 2b01164..a92f967 100644 --- a/src/defaults.rs +++ b/src/defaults.rs @@ -1,10 +1,11 @@ +use anyhow::{Error, Result}; use mlua::UserData; use std::{ collections::HashMap, sync::{Arc, OnceLock, RwLock}, }; -static GLOBAL_STATE: OnceLock = OnceLock::new(); +static GLOBAL_DEFAULTS: OnceLock = OnceLock::new(); #[derive(Clone)] pub struct Defaults { @@ -23,16 +24,16 @@ pub struct Defaults { } impl Defaults { - pub fn new() -> Self { + pub fn new() -> Result { let env = Arc::new(RwLock::new(HashMap::new())); match env.write() { Ok(mut env) => { env.insert("DEBIAN_FRONTEND".to_string(), "noninteractive".to_string()); } - Err(_) => {} + Err(_) => return Err(Error::msg("Failed to acquire write lock".to_string())), } - Self { + Ok(Self { port: Arc::new(RwLock::new(22)), user: Arc::new(RwLock::new(None)), private_key_file: Arc::new(RwLock::new(None)), @@ -44,15 +45,17 @@ impl Defaults { as_user: Arc::new(RwLock::new(None)), known_hosts_file: Arc::new(RwLock::new(format!( "{}/.ssh/known_hosts", - std::env::var("HOME").unwrap() + std::env::var("HOME").unwrap_or("~".to_string()) ))), host_key_check: Arc::new(RwLock::new(true)), env, - } + }) } - pub fn global() -> Self { - GLOBAL_STATE.get_or_init(|| Defaults::new()).clone() + pub fn global() -> Result { + Ok(GLOBAL_DEFAULTS + .get_or_init(|| Defaults::new().unwrap()) + .clone()) } } @@ -61,11 +64,9 @@ impl UserData for Defaults { methods.add_method("get_port", |_, this, ()| -> mlua::Result { match this.port.read() { Ok(port) => Ok(*port), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -75,22 +76,18 @@ impl UserData for Defaults { *port = new_port; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }); methods.add_method("get_user", |_, this, ()| -> mlua::Result> { match this.user.read() { Ok(user) => Ok(user.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -102,11 +99,9 @@ impl UserData for Defaults { *user = new_user; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -116,11 +111,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.private_key_file.read() { Ok(private_key_file) => Ok(private_key_file.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -133,11 +126,9 @@ impl UserData for Defaults { *private_key_file = new_private_key_file; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -147,11 +138,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.private_key_pass.read() { Ok(private_key_pass) => Ok(private_key_pass.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -164,11 +153,9 @@ impl UserData for Defaults { *private_key_pass = new_private_key_pass; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -178,11 +165,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.password.read() { Ok(password) => Ok(password.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -195,11 +180,9 @@ impl UserData for Defaults { *password = new_password; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -209,11 +192,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result { match this.ignore_exit_code.read() { Ok(ignore_exit_code) => Ok(*ignore_exit_code), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -226,11 +207,9 @@ impl UserData for Defaults { *ignore_exit_code = new_ignore_exit_code; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -238,11 +217,9 @@ impl UserData for Defaults { methods.add_method("get_elevate", |_, this, ()| -> mlua::Result { match this.elevate.read() { Ok(elevate) => Ok(*elevate), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -254,11 +231,9 @@ impl UserData for Defaults { *elevate = new_elevate; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -268,11 +243,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result { match this.elevation_method.read() { Ok(elevation_method) => Ok(elevation_method.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -285,11 +258,9 @@ impl UserData for Defaults { *elevation_method = new_elevation_method; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -299,11 +270,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result> { match this.as_user.read() { Ok(as_user) => Ok(as_user.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -316,11 +285,9 @@ impl UserData for Defaults { *as_user = new_as_user; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -330,11 +297,9 @@ impl UserData for Defaults { |_, this, ()| -> mlua::Result { match this.known_hosts_file.read() { Ok(known_hosts_file) => Ok(known_hosts_file.clone()), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }, ); @@ -347,11 +312,9 @@ impl UserData for Defaults { *known_hosts_file = new_known_hosts_file; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -359,11 +322,9 @@ impl UserData for Defaults { methods.add_method("get_host_key_check", |_, this, ()| -> mlua::Result { match this.host_key_check.read() { Ok(host_key_check) => Ok(*host_key_check), - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire read lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire read lock".to_string(), + )), } }); @@ -375,11 +336,9 @@ impl UserData for Defaults { *host_key_check = new_host_key_check; Ok(()) } - Err(_) => { - return Err(mlua::Error::RuntimeError( - "Failed to acquire write lock".to_string(), - )) - } + Err(_) => Err(mlua::Error::RuntimeError( + "Failed to acquire write lock".to_string(), + )), } }, ); @@ -433,8 +392,8 @@ mod tests { use super::*; #[test] - fn test_defaults_new() { - let defaults = Defaults::new(); + fn test_defaults_new() -> Result<()> { + let defaults = Defaults::new()?; // Test default values assert_eq!(*defaults.port.read().unwrap(), 22); @@ -442,11 +401,11 @@ mod tests { assert_eq!(*defaults.private_key_file.read().unwrap(), None); assert_eq!(*defaults.private_key_pass.read().unwrap(), None); assert_eq!(*defaults.password.read().unwrap(), None); - assert_eq!(*defaults.ignore_exit_code.read().unwrap(), false); - assert_eq!(*defaults.elevate.read().unwrap(), false); + assert!(!(*defaults.ignore_exit_code.read().unwrap())); + assert!(!(*defaults.elevate.read().unwrap())); assert_eq!(*defaults.elevation_method.read().unwrap(), "sudo"); assert_eq!(*defaults.as_user.read().unwrap(), None); - assert_eq!(*defaults.host_key_check.read().unwrap(), true); + assert!(*defaults.host_key_check.read().unwrap()); // Test default environment variables let env = defaults.env.read().unwrap(); @@ -454,24 +413,28 @@ mod tests { env.get("DEBIAN_FRONTEND"), Some(&"noninteractive".to_string()) ); + + Ok(()) } #[test] - fn test_global_singleton() { - let defaults1 = Defaults::global(); - let defaults2 = Defaults::global(); + fn test_global_singleton() -> Result<()> { + let defaults1 = Defaults::global()?; + let defaults2 = Defaults::global()?; // Modify a value using the first instance *defaults1.port.write().unwrap() = 2222; // Check if the change is reflected in the second instance assert_eq!(*defaults2.port.read().unwrap(), 2222); + + Ok(()) } #[test] - fn test_lua_interface() -> mlua::Result<()> { + fn test_lua_interface() -> Result<()> { let lua = mlua::Lua::new(); - let defaults = Defaults::new(); + let defaults = Defaults::new()?; // Register the defaults instance with Lua lua.globals().set("defaults", defaults.clone())?; diff --git a/src/lib.rs b/src/lib.rs index b9cd43c..246c592 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod ssh; mod util; mod validator; +use anyhow::Result; use args::Args; use clap::Parser; use defaults::Defaults; @@ -13,7 +14,7 @@ use mlua::{ Error::{self, RuntimeError}, FromLua, Integer, IntoLua, Lua, LuaSerdeExt, MultiValue, Table, UserData, Value, }; -use modules::{base_module, collect_modules}; +use modules::{base_module, collect_core_modules}; use rayon::prelude::*; use rustyline::DefaultEditor; use serde::{Deserialize, Serialize}; @@ -62,10 +63,10 @@ pub fn create_lua() -> mlua::Result { pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { let komandan = lua.create_table()?; - let defaults = Defaults::global(); + let defaults = Defaults::global()?; komandan.set("defaults", defaults)?; - let base_module = base_module(&lua); + let base_module = base_module(lua)?; komandan.set("KomandanModule", base_module)?; komandan.set("komando", lua.create_function(komando)?)?; @@ -92,7 +93,7 @@ pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { komandan.set("dprint", lua.create_function(dprint)?)?; // Add core modules - komandan.set("modules", collect_modules(lua))?; + komandan.set("modules", collect_core_modules(lua)?)?; lua.globals().set("komandan", &komandan)?; @@ -100,7 +101,7 @@ pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { } fn get_user(host: &Table, task: &Table) -> mlua::Result { - let defaults = Defaults::global(); + let defaults = Defaults::global()?; let default_user = match defaults.user.read() { Ok(user) => user, Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), @@ -130,7 +131,7 @@ fn get_auth_config(host: &Table, task: &Table) -> mlua::Result<(String, SSHAuthM let user = get_user(host, task)?; - let defaults = Defaults::global(); + let defaults = Defaults::global()?; let default_private_key_file = match defaults.private_key_file.read() { Ok(private_key_file) => private_key_file, @@ -152,10 +153,7 @@ fn get_auth_config(host: &Table, task: &Table) -> mlua::Result<(String, SSHAuthM private_key: private_key_file, passphrase: match host.get::("private_key_pass") { Ok(passphrase) => Some(passphrase), - Err(_) => match *default_private_key_pass { - Some(ref private_key_pass) => Some(private_key_pass.clone()), - None => None, - }, + Err(_) => (*default_private_key_pass).clone(), }, }, Err(_) => match *default_private_key_file { @@ -163,10 +161,7 @@ fn get_auth_config(host: &Table, task: &Table) -> mlua::Result<(String, SSHAuthM private_key: private_key_file.clone(), passphrase: match host.get::("private_key_pass") { Ok(passphrase) => Some(passphrase), - Err(_) => match *default_private_key_pass { - Some(ref passphrase) => Some(passphrase.clone()), - None => None, - }, + Err(_) => (*default_private_key_pass).clone(), }, }, None => match host.get::("password") { @@ -188,7 +183,7 @@ fn get_auth_config(host: &Table, task: &Table) -> mlua::Result<(String, SSHAuthM } fn get_elevation_config(host: &Table, task: &Table) -> mlua::Result { - let defaults = Defaults::global(); + let defaults = Defaults::global()?; let default_elevate = match defaults.elevate.read() { Ok(elevate) => elevate, @@ -250,7 +245,7 @@ fn get_elevation_config(host: &Table, task: &Table) -> mlua::Result { } fn setup_ssh_session(host: &Table) -> mlua::Result { - let defaults = Defaults::global(); + let defaults = Defaults::global()?; let mut ssh = SSHSession::new()?; let default_host_key_check = match defaults.host_key_check.read() { @@ -260,7 +255,7 @@ fn setup_ssh_session(host: &Table) -> mlua::Result { let host_key_check = match host.get::("host_key_check") { Ok(host_key_check) => match host_key_check { - Value::Nil => default_host_key_check.clone(), + Value::Nil => *default_host_key_check, Value::Boolean(false) => false, _ => true, }, @@ -283,7 +278,7 @@ fn setup_ssh_session(host: &Table) -> mlua::Result { } fn setup_environment(ssh: &mut SSHSession, host: &Table, task: &Table) -> mlua::Result<()> { - let defaults = Defaults::global(); + let defaults = Defaults::global()?; let default_env = match defaults.env.read() { Ok(env) => env, @@ -297,14 +292,14 @@ fn setup_environment(ssh: &mut SSHSession, host: &Table, task: &Table) -> mlua:: ssh.set_env(&key, &value); } - if !env_host.is_none() { + if env_host.is_some() { for pair in env_host.unwrap().pairs() { let (key, value): (String, String) = pair?; ssh.set_env(&key, &value); } } - if !env_task.is_none() { + if env_task.is_some() { for pair in env_task.unwrap().pairs() { let (key, value): (String, String) = pair?; ssh.set_env(&key, &value); @@ -368,7 +363,7 @@ fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result { let host_display = host_display(&host); let task_display = task_display(&task); - let defaults = Defaults::global(); + let defaults = Defaults::global()?; let (user, ssh_auth_method) = get_auth_config(&host, &task)?; let elevation = get_elevation_config(&host, &task)?; @@ -378,9 +373,7 @@ fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result
{ Err(_) => return Err(RuntimeError("Failed to acquire read lock".to_string())), }; - let port = host - .get::("port") - .unwrap_or(default_port.clone() as i64) as u16; + let port = host.get::("port").unwrap_or(*default_port as i64) as u16; let mut ssh = setup_ssh_session(&host)?; ssh.elevation = elevation; @@ -403,9 +396,9 @@ fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result
{ let ignore_exit_code = task .get::("ignore_exit_code") - .unwrap_or(default_ignore_exit_code.clone()); + .unwrap_or(*default_ignore_exit_code); - if results.get::("exit_code").unwrap() != 0 && !ignore_exit_code { + if results.get::("exit_code")? != 0 && !ignore_exit_code { return Err(RuntimeError("Failed to run task.".to_string())); } @@ -630,11 +623,11 @@ fn komando_parallel_tasks(lua: &Lua, (host, tasks): (Value, Value)) -> mlua::Res let task = task.clone().into_lua(&lua).unwrap(); let result = komando(&lua, (host, task)).unwrap(); - return ( + ( *i, lua.from_value::(Value::Table(result)) .unwrap(), - ); + ) }) .collect::>(); @@ -665,11 +658,11 @@ fn komando_parallel_hosts(lua: &Lua, (hosts, task): (Value, Value)) -> mlua::Res let task = task.clone().into_lua(&lua).unwrap(); let result = komando(&lua, (host, task)).unwrap(); - return ( + ( *i, lua.from_value::(Value::Table(result)) .unwrap(), - ); + ) }) .collect::>(); @@ -683,7 +676,7 @@ fn komando_parallel_hosts(lua: &Lua, (hosts, task): (Value, Value)) -> mlua::Res Ok(results_table) } -pub fn run_main_file(lua: &Lua, main_file: &String) -> anyhow::Result<()> { +pub fn run_main_file(lua: &Lua, main_file: &String) -> Result<()> { let script = match fs::read_to_string(main_file) { Ok(script) => script, Err(e) => { @@ -731,7 +724,7 @@ pub fn repl(lua: &Lua) { incomplete_input: true, .. }) => { - line.push_str("\n"); + line.push('\n'); prompt = ">> "; } Err(e) => { @@ -772,65 +765,66 @@ mod tests { } #[test] - fn test_setup_komandan_table() { - let lua = create_lua().unwrap(); + fn test_setup_komandan_table() -> Result<()> { + let lua = create_lua()?; // Assert that the komandan table is set up correctly - let komandan_table = lua.globals().get::
("komandan").unwrap(); - assert!(komandan_table.contains_key("defaults").unwrap()); - assert!(komandan_table.contains_key("KomandanModule").unwrap()); - assert!(komandan_table.contains_key("komando").unwrap()); - assert!(komandan_table.contains_key("regex_is_match").unwrap()); - assert!(komandan_table.contains_key("filter_hosts").unwrap()); - assert!(komandan_table - .contains_key("parse_hosts_json_file") - .unwrap()); - assert!(komandan_table.contains_key("parse_hosts_json_url").unwrap()); - assert!(komandan_table.contains_key("dprint").unwrap()); - - let modules_table = komandan_table.get::
("modules").unwrap(); - assert!(modules_table.contains_key("apt").unwrap()); - assert!(modules_table.contains_key("cmd").unwrap()); - assert!(modules_table.contains_key("lineinfile").unwrap()); - assert!(modules_table.contains_key("script").unwrap()); - assert!(modules_table.contains_key("systemd_service").unwrap()); - assert!(modules_table.contains_key("template").unwrap()); - assert!(modules_table.contains_key("upload").unwrap()); - assert!(modules_table.contains_key("download").unwrap()); + let komandan_table = lua.globals().get::
("komandan")?; + assert!(komandan_table.contains_key("defaults")?); + assert!(komandan_table.contains_key("KomandanModule")?); + assert!(komandan_table.contains_key("komando")?); + assert!(komandan_table.contains_key("regex_is_match")?); + assert!(komandan_table.contains_key("filter_hosts")?); + assert!(komandan_table.contains_key("parse_hosts_json_file")?); + assert!(komandan_table.contains_key("parse_hosts_json_url")?); + assert!(komandan_table.contains_key("dprint")?); + + let modules_table = komandan_table.get::
("modules")?; + assert!(modules_table.contains_key("apt")?); + assert!(modules_table.contains_key("cmd")?); + assert!(modules_table.contains_key("lineinfile")?); + assert!(modules_table.contains_key("script")?); + assert!(modules_table.contains_key("systemd_service")?); + assert!(modules_table.contains_key("template")?); + assert!(modules_table.contains_key("upload")?); + assert!(modules_table.contains_key("download")?); + + Ok(()) } #[test] - fn test_run_main_file() { - let lua = create_lua().unwrap(); + fn test_run_main_file() -> Result<()> { + let lua = create_lua()?; // Test with a valid Lua file let main_file = "examples/hosts.lua".to_string(); let result = run_main_file(&lua, &main_file); assert!(result.is_ok()); + + Ok(()) } #[test] - fn test_get_auth_config() { - let lua = create_lua().unwrap(); - let host = lua.create_table().unwrap(); + fn test_get_auth_config() -> Result<()> { + let lua = create_lua()?; + let host = lua.create_table()?; // Test with user in host - host.set("address", "localhost").unwrap(); - host.set("user", "testuser").unwrap(); - host.set("private_key_file", "/path/to/key").unwrap(); + host.set("address", "localhost")?; + host.set("user", "testuser")?; + host.set("private_key_file", "/path/to/key")?; - let module_params = lua.create_table().unwrap(); - module_params.set("cmd", "echo test").unwrap(); + let module_params = lua.create_table()?; + module_params.set("cmd", "echo test")?; let module = lua .load(chunk! { return komandan.modules.cmd($module_params) }) - .eval::
() - .unwrap(); - let task = lua.create_table().unwrap(); - task.set(1, module).unwrap(); + .eval::
()?; + let task = lua.create_table()?; + task.set(1, module)?; - let (user, auth) = get_auth_config(&host, &task).unwrap(); + let (user, auth) = get_auth_config(&host, &task)?; assert_eq!(user, "testuser"); match auth { SSHAuthMethod::PublicKey { @@ -844,28 +838,30 @@ mod tests { } // Test with password auth - host.set("private_key_file", Value::Nil).unwrap(); - host.set("password", "testpass").unwrap(); - let (_, auth) = get_auth_config(&host, &task).unwrap(); + host.set("private_key_file", Value::Nil)?; + host.set("password", "testpass")?; + let (_, auth) = get_auth_config(&host, &task)?; match auth { SSHAuthMethod::Password(pass) => assert_eq!(pass, "testpass"), _ => panic!("Expected Password authentication"), } // Test with no authentication method - host.set("password", Value::Nil).unwrap(); + host.set("password", Value::Nil)?; let result = get_auth_config(&host, &task); assert!(result.is_err()); + + Ok(()) } #[test] - fn test_get_elevation_config() { - let lua = create_lua().unwrap(); - let host = lua.create_table().unwrap(); - let task = lua.create_table().unwrap(); + fn test_get_elevation_config() -> Result<()> { + let lua = create_lua()?; + let host = lua.create_table()?; + let task = lua.create_table()?; // Test with no elevation - let elevation = get_elevation_config(&host, &task).unwrap(); + let elevation = get_elevation_config(&host, &task)?; assert!(matches!( elevation, Elevation { @@ -875,8 +871,8 @@ mod tests { )); // Test with elevation from task - task.set("elevate", true).unwrap(); - let elevation = get_elevation_config(&host, &task).unwrap(); + task.set("elevate", true)?; + let elevation = get_elevation_config(&host, &task)?; assert!(matches!( elevation, Elevation { @@ -886,8 +882,8 @@ mod tests { )); // Test with custom elevation method - task.set("elevation_method", "su").unwrap(); - let elevation = get_elevation_config(&host, &task).unwrap(); + task.set("elevation_method", "su")?; + let elevation = get_elevation_config(&host, &task)?; assert!(matches!( elevation, Elevation { @@ -897,65 +893,69 @@ mod tests { )); // Test invalid elevation method - task.set("elevation_method", "invalid").unwrap(); + task.set("elevation_method", "invalid")?; assert!(get_elevation_config(&host, &task).is_err()); + + Ok(()) } #[test] - fn test_setup_ssh_session() { - let lua = create_lua().unwrap(); - let host = lua.create_table().unwrap(); - host.set("address", "localhost").unwrap(); + fn test_setup_ssh_session() -> Result<()> { + let lua = create_lua()?; + let host = lua.create_table()?; + host.set("address", "localhost")?; // Test with default settings - let ssh = setup_ssh_session(&host).unwrap(); + let ssh = setup_ssh_session(&host)?; assert!(ssh.known_hosts_file.is_some()); // Test with host key check disabled - host.set("host_key_check", false).unwrap(); - let ssh = setup_ssh_session(&host).unwrap(); + host.set("host_key_check", false)?; + let ssh = setup_ssh_session(&host)?; assert!(ssh.known_hosts_file.is_none()); // Test with custom known_hosts file - host.set("known_hosts_file", "/path/to/known_hosts") - .unwrap(); - host.set("host_key_check", true).unwrap(); - let ssh = setup_ssh_session(&host).unwrap(); + host.set("known_hosts_file", "/path/to/known_hosts")?; + host.set("host_key_check", true)?; + let ssh = setup_ssh_session(&host)?; assert_eq!(ssh.known_hosts_file.unwrap(), "/path/to/known_hosts"); // Test with known_hosts from defaults - host.set("known_hosts_file", Value::Nil).unwrap(); + host.set("known_hosts_file", Value::Nil)?; lua.load(chunk! { komandan.defaults:set_known_hosts_file("/default/known_hosts") }) - .exec() - .unwrap(); - let ssh = setup_ssh_session(&host).unwrap(); + .exec()?; + let ssh = setup_ssh_session(&host)?; assert_eq!(ssh.known_hosts_file.unwrap(), "/default/known_hosts"); + + Ok(()) } #[test] - fn test_setup_environment() { - let lua = create_lua().unwrap(); - let mut ssh = SSHSession::new().unwrap(); - let defaults = lua.create_table().unwrap(); - let host = lua.create_table().unwrap(); - let task = lua.create_table().unwrap(); + fn test_setup_environment() -> Result<()> { + let lua = create_lua()?; + let mut ssh = SSHSession::new()?; + let defaults = lua.create_table()?; + let host = lua.create_table()?; + let task = lua.create_table()?; // Test with environment variables at all levels - let env_defaults = lua.create_table().unwrap(); - env_defaults.set("DEFAULT_VAR", "default_value").unwrap(); - defaults.set("env", env_defaults).unwrap(); + let env_defaults = lua.create_table()?; + env_defaults.set("DEFAULT_VAR", "default_value")?; + defaults.set("env", env_defaults)?; + + let env_host = lua.create_table()?; + env_host.set("HOST_VAR", "host_value")?; + env_host.set("DEFAULT_VAR", "overridden_value")?; // Override default + host.set("env", env_host)?; - let env_host = lua.create_table().unwrap(); - env_host.set("HOST_VAR", "host_value").unwrap(); - env_host.set("DEFAULT_VAR", "overridden_value").unwrap(); // Override default - host.set("env", env_host).unwrap(); + let env_task = lua.create_table()?; + env_task.set("TASK_VAR", "task_value")?; + task.set("env", env_task)?; - let env_task = lua.create_table().unwrap(); - env_task.set("TASK_VAR", "task_value").unwrap(); - task.set("env", env_task).unwrap(); + setup_environment(&mut ssh, &host, &task)?; - setup_environment(&mut ssh, &host, &task).unwrap(); + Ok(()) } } diff --git a/src/main.rs b/src/main.rs index 15b422b..ebec3a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ mod args; +use anyhow::Result; use args::Args; use clap::Parser; use komandan::{create_lua, print_version, repl, run_main_file}; -fn main() -> anyhow::Result<()> { +fn main() -> Result<()> { let args = Args::parse(); if args.version { diff --git a/src/modules/apt.rs b/src/modules/apt.rs index a98921b..b5d320a 100644 --- a/src/modules/apt.rs +++ b/src/modules/apt.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn apt(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.update_cache == nil then diff --git a/src/modules/base.rs b/src/modules/base.rs new file mode 100644 index 0000000..f86d982 --- /dev/null +++ b/src/modules/base.rs @@ -0,0 +1,22 @@ +use mlua::{chunk, Table}; + +pub fn base_module(lua: &mlua::Lua) -> mlua::Result
{ + lua.load(chunk! { + local KomandanModule = {} + + KomandanModule.new = function(self,data) + local o = setmetatable({}, { __index = self }) + o.name = data.name + return o + end + + KomandanModule.run = function(self) + end + + KomandanModule.cleanup = function(self) + end + + return KomandanModule + }) + .eval::
() +} diff --git a/src/modules/cmd.rs b/src/modules/cmd.rs index 68cfc4a..86386dd 100644 --- a/src/modules/cmd.rs +++ b/src/modules/cmd.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn cmd(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "cmd" }) diff --git a/src/modules/core.rs b/src/modules/core.rs new file mode 100644 index 0000000..8edb8f0 --- /dev/null +++ b/src/modules/core.rs @@ -0,0 +1,19 @@ +use mlua::Table; + +use super::*; + +pub fn collect_core_modules(lua: &mlua::Lua) -> mlua::Result
{ + let modules = lua.create_table()?; + modules.set("apt", lua.create_function(apt::apt)?)?; + modules.set("cmd", lua.create_function(cmd::cmd)?)?; + modules.set("download", lua.create_function(download::download)?)?; + modules.set("lineinfile", lua.create_function(lineinfile::lineinfile)?)?; + modules.set("script", lua.create_function(script::script)?)?; + modules.set( + "systemd_service", + lua.create_function(systemd_service::systemd_service)?, + )?; + modules.set("template", lua.create_function(template::template)?)?; + modules.set("upload", lua.create_function(upload::upload)?)?; + Ok(modules) +} diff --git a/src/modules/download.rs b/src/modules/download.rs index 0873301..8bf954d 100644 --- a/src/modules/download.rs +++ b/src/modules/download.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn download(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "download" }) diff --git a/src/modules/lineinfile.rs b/src/modules/lineinfile.rs index b14bdf5..9d073a0 100644 --- a/src/modules/lineinfile.rs +++ b/src/modules/lineinfile.rs @@ -8,7 +8,7 @@ pub fn lineinfile(lua: &Lua, params: Table) -> mlua::Result
{ .take(10) .collect(); - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.path == nil then diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 6e766d7..0d0b3dc 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -1,5 +1,7 @@ mod apt; +mod base; mod cmd; +mod core; mod download; mod lineinfile; mod script; @@ -7,63 +9,5 @@ mod systemd_service; mod template; mod upload; -use mlua::{chunk, Table}; - -pub fn base_module(lua: &mlua::Lua) -> Table { - return lua - .load(chunk! { - local KomandanModule = {} - - KomandanModule.new = function(self,data) - local o = setmetatable({}, { __index = self }) - o.name = data.name - return o - end - - KomandanModule.run = function(self) - end - - KomandanModule.cleanup = function(self) - end - - return KomandanModule - }) - .eval::
() - .unwrap(); -} - -pub fn collect_modules(lua: &mlua::Lua) -> Table { - let modules = lua.create_table().unwrap(); - modules - .set("apt", lua.create_function(apt::apt).unwrap()) - .unwrap(); - modules - .set("cmd", lua.create_function(cmd::cmd).unwrap()) - .unwrap(); - modules - .set("download", lua.create_function(download::download).unwrap()) - .unwrap(); - modules - .set( - "lineinfile", - lua.create_function(lineinfile::lineinfile).unwrap(), - ) - .unwrap(); - modules - .set("script", lua.create_function(script::script).unwrap()) - .unwrap(); - modules - .set( - "systemd_service", - lua.create_function(systemd_service::systemd_service) - .unwrap(), - ) - .unwrap(); - modules - .set("template", lua.create_function(template::template).unwrap()) - .unwrap(); - modules - .set("upload", lua.create_function(upload::upload).unwrap()) - .unwrap(); - return modules; -} +pub use base::*; +pub use core::*; diff --git a/src/modules/script.rs b/src/modules/script.rs index 360ea3e..e4ee06b 100644 --- a/src/modules/script.rs +++ b/src/modules/script.rs @@ -8,7 +8,7 @@ pub fn script(lua: &Lua, params: Table) -> mlua::Result
{ .take(10) .collect(); - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.script == nil and params.from_file == nil then diff --git a/src/modules/systemd_service.rs b/src/modules/systemd_service.rs index fc924b9..3e62ea0 100644 --- a/src/modules/systemd_service.rs +++ b/src/modules/systemd_service.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn systemd_service(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { if params.name == nil then diff --git a/src/modules/template.rs b/src/modules/template.rs index fe0c2c6..7476132 100644 --- a/src/modules/template.rs +++ b/src/modules/template.rs @@ -40,7 +40,7 @@ pub fn template(lua: &Lua, params: Table) -> mlua::Result
{ .take(10) .collect(); - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "template" }) diff --git a/src/modules/upload.rs b/src/modules/upload.rs index 1e0c3a3..7ab46be 100644 --- a/src/modules/upload.rs +++ b/src/modules/upload.rs @@ -1,7 +1,7 @@ use mlua::{chunk, ExternalResult, Lua, Table}; pub fn upload(lua: &Lua, params: Table) -> mlua::Result
{ - let base_module = super::base_module(&lua); + let base_module = super::base_module(lua)?; let module = lua .load(chunk! { local module = $base_module:new({ name = "upload" }) diff --git a/src/ssh.rs b/src/ssh.rs index c398aa7..0f61e72 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -6,7 +6,7 @@ use std::{ path::Path, }; -use anyhow::Result; +use anyhow::{Error, Result}; use mlua::UserData; use ssh2::{CheckResult, KnownHostFileKind, Session, Sftp}; @@ -73,42 +73,39 @@ impl SSHSession { self.session.set_tcp_stream(tcp); self.session.handshake()?; - match &self.known_hosts_file { - Some(file) => { - let host_key = self.session.host_key().unwrap(); - let mut known_hosts = self.session.known_hosts()?; - match known_hosts.read_file(Path::new(file.as_str()), KnownHostFileKind::OpenSSH) { - Ok(_) => {} - Err(_) => { - return Err(anyhow::Error::msg( - format!("SSH host key verification failed. Please add the host key to the known_hosts file: {}", file), - )); - } - }; - - let known_hosts_check_result = known_hosts.check(&address, &host_key.0); - match known_hosts_check_result { - CheckResult::Match => {} - _ => { - return Err(anyhow::Error::msg( - format!("SSH host key verification failed ({:?}). Please check the known_hosts file: {}", known_hosts_check_result, file), - )); - } - }; - } - None => {} + if let Some(file) = &self.known_hosts_file { + let host_key = self.session.host_key().unwrap(); + let mut known_hosts = self.session.known_hosts()?; + match known_hosts.read_file(Path::new(file.as_str()), KnownHostFileKind::OpenSSH) { + Ok(_) => {} + Err(_) => { + return Err(Error::msg( + format!("SSH host key verification failed. Please add the host key to the known_hosts file: {}", file) + )); + } + }; + + let known_hosts_check_result = known_hosts.check(address, host_key.0); + match known_hosts_check_result { + CheckResult::Match => {} + _ => { + return Err(Error::msg( + format!("SSH host key verification failed ({:?}). Please check the known_hosts file: {}", known_hosts_check_result, file) + )); + } + }; } match auth_method { SSHAuthMethod::Password(password) => { - self.session.userauth_password(&username, &password)?; + self.session.userauth_password(username, &password)?; } SSHAuthMethod::PublicKey { private_key, passphrase, } => { self.session.userauth_pubkey_file( - &username, + username, None, Path::new(&private_key), passphrase.as_deref(), @@ -117,7 +114,7 @@ impl SSHSession { } if !self.session.authenticated() { - return Err(anyhow::Error::msg("SSH authentication failed.")); + return Err(Error::msg("SSH authentication failed.")); } Ok(()) @@ -240,7 +237,7 @@ impl SSHSession { let mut remote_file = self.session .scp_send(Path::new(remote_path), 0o644, content_length, None)?; - remote_file.write(content)?; + remote_file.write_all(content)?; remote_file.send_eof()?; remote_file.wait_eof()?; remote_file.close()?; @@ -260,7 +257,7 @@ fn upload_file(sftp: &mut Sftp, local_path: &Path, remote_path: &Path) -> io::Re } fn upload_directory(sftp: &mut Sftp, local_path: &Path, remote_path: &Path) -> io::Result<()> { - if !sftp.stat(remote_path).is_ok() { + if sftp.stat(remote_path).is_err() { sftp.mkdir(remote_path, 0o755)?; } diff --git a/src/util.rs b/src/util.rs index 5f7bc2a..d06ef8a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -116,10 +116,10 @@ pub fn parse_hosts_json_file(lua: &Lua, path: Value) -> mlua::Result
{ Ok(_) => match parse_hosts_json(lua, content) { Ok(h) => h, Err(_) => { - return Err(RuntimeError(String::from(format!( + return Err(RuntimeError(format!( "Failed to parse JSON file from '{}'", path - )))); + ))); } }, Err(_) => return Err(RuntimeError(String::from("Failed to read JSON file"))), @@ -158,20 +158,14 @@ pub fn parse_hosts_json_url(lua: &Lua, url: Value) -> mlua::Result
{ String::from_utf8_lossy(&response.body).to_string() } Err(e) => { - return Err(RuntimeError(String::from(format!( - "Failed to fetch URL: {:?}", - e - )))); + return Err(RuntimeError(format!("Failed to fetch URL: {:?}", e))); } }; let hosts = match parse_hosts_json(lua, content) { Ok(h) => h, Err(_) => { - return Err(RuntimeError(String::from(format!( - "Failed to parse JSON from '{}'", - url - )))); + return Err(RuntimeError(format!("Failed to parse JSON from '{}'", url))); } }; @@ -209,12 +203,9 @@ fn parse_hosts_json(lua: &Lua, content: String) -> mlua::Result
{ for pair in lua_table.pairs() { let (_, value): (Value, Value) = pair?; - match validate_host(&lua, value) { - Ok(host) => { - hosts.set(hosts.len()? + 1, host)?; - } - Err(_) => {} - }; + if let Ok(host) = validate_host(lua, value) { + hosts.set(hosts.len()? + 1, host)?; + } } Ok(hosts) @@ -236,7 +227,7 @@ pub fn host_display(host: &Table) -> String { match host.get::("name") { Ok(name) => format!("{} ({})", name, address), - Err(_) => format!("{}", address), + Err(_) => address, } } diff --git a/src/validator.rs b/src/validator.rs index a8e96f1..823273d 100644 --- a/src/validator.rs +++ b/src/validator.rs @@ -2,15 +2,15 @@ use mlua::{chunk, Error::RuntimeError, ExternalResult, Integer, Lua, Table, Valu pub fn validate_host(lua: &Lua, host: Value) -> mlua::Result
{ if !host.is_table() { - return Err(RuntimeError(format!("Host is not a table."))); + return Err(RuntimeError("Host is not a table.".to_string())); } let address = host.as_table().unwrap().get::("address")?; if address.is_nil() { - return Err(RuntimeError(format!("Host address is empty."))); + return Err(RuntimeError("Host address is empty.".to_string())); } if !address.is_string() { - return Err(RuntimeError(format!("Host address is invalid."))); + return Err(RuntimeError("Host address is invalid.".to_string())); } let port = host.as_table().unwrap().get::("port")?; @@ -245,7 +245,7 @@ mod tests { let result = super::validate_module(&lua, mlua::Value::String(lua.create_string("ls").unwrap())); - eprint!("result: {:#?}\n", result.clone().err()); + eprintln!("result: {:#?}", result.clone().err()); assert!(result.is_ok()); }