From 3120a9ca6b1cfbb5aa15793e1a1f12c2b77a8824 Mon Sep 17 00:00:00 2001 From: Abdul Munif Hanafi Date: Wed, 25 Dec 2024 22:38:13 +0700 Subject: [PATCH] Change defaults object from table to userdata --- .github/workflows/coverage.yml | 2 +- .github/workflows/rust.yml | 2 +- README.md | 27 +- examples/multi_hosts_tasks.lua | 6 +- examples/parse_hosts_json.lua | 6 +- http-client/src/lib.rs | 187 ++++++++++++ src/defaults.rs | 444 ++++++++++++++++++++++++++-- src/lib.rs | 511 +++++++++++++++++++++++---------- src/ssh.rs | 20 +- src/util.rs | 61 ---- tests/komando.rs | 4 +- 11 files changed, 1010 insertions(+), 260 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index a4c1ba7..0428f88 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -30,7 +30,7 @@ jobs: - name: Generate coverage report run: | - cargo +nightly tarpaulin --verbose --out Xml --implicit-test-threads + cargo +nightly tarpaulin --verbose --out Xml --implicit-test-threads --workspace - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index d5a187b..15601a1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -46,4 +46,4 @@ jobs: echo "127.0.0.1 localhost2" | sudo tee -a /etc/hosts > /dev/null - name: Run tests - run: cargo test --verbose + run: cargo test --verbose --workspace diff --git a/README.md b/README.md index 88b66d5..4e9a37b 100644 --- a/README.md +++ b/README.md @@ -139,11 +139,34 @@ For detailed explanations, arguments, and examples of each module, please refer Komandan offers built-in functions to enhance scripting capabilities: - **`komandan.filter_hosts`**: Filters a list of hosts based on a pattern. -- **`komandan.parse_hosts_json`**: Parses a JSON file containing hosts information. -- **`komandan.set_defaults`**: Sets default values for host connection parameters. +- **`komandan.parse_hosts_json_file`**: Parses a JSON file containing hosts information. +- **`komandan.parse_hosts_json_url`**: Parses a JSON file from a URL containing hosts information. For detailed descriptions and usage examples of these functions, please visit the [Built-in Functions section of the Komandan Documentation Site](https://komandan.vercel.app/docs/functions/). +## Default Values + +Komandan provides default values for various parameters, such as the user, private key file path, and SSH port. These values can be set using the `komandan.defaults` userdata. + +```lua +-- set default values +komandan.defaults:set_port(22) +komandan.defaults:set_user("user1") +komandan.defaults:set_private_key_file(os.getenv("HOME") .. "/.ssh/id_ed25519") +komandan.defaults:set_private_key_pass("passphrase") +komandan.defaults:set_host_key_check(false) +komandan.defaults:set_env("ENV_VAR", "value") + +-- get default values +local port = komandan.defaults:get_port() +local user = komandan.defaults:get_user() +local private_key_file = komandan.defaults:get_private_key_file() +local private_key_pass = komandan.defaults:get_private_key_pass() +local host_key_check = komandan.defaults:get_host_key_check() +local env = komandan.defaults:get_env("ENV_VAR") +local env_all = komandan.defaults:get_all_env() +``` + ## Error Handling Komandan provides error information through the return values of the `komando` function. If a task fails, the `exit_code` will be non-zero, and `stderr` may contain error messages. You can use the `ignore_exit_code` option in a task to continue execution even if a task fails. diff --git a/examples/multi_hosts_tasks.lua b/examples/multi_hosts_tasks.lua index 3708594..aab7be3 100644 --- a/examples/multi_hosts_tasks.lua +++ b/examples/multi_hosts_tasks.lua @@ -1,9 +1,7 @@ local hosts = require("hosts") -komandan.set_defaults({ - user = "user1", - private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519", -}) +komandan.defaults:set_user("user1") +komandan.defaults:set_private_key_file(os.getenv("HOME") .. "/.ssh/id_ed25519") local tasks = { { diff --git a/examples/parse_hosts_json.lua b/examples/parse_hosts_json.lua index cbb1d4a..a220320 100644 --- a/examples/parse_hosts_json.lua +++ b/examples/parse_hosts_json.lua @@ -2,10 +2,8 @@ local hosts = komandan.parse_hosts_json_file("/path/to/hosts.json") -- or use a URL -- local hosts = komandan.parse_hosts_json_url("http://localhost:8000/hosts.json") -komandan.set_defaults({ - user = "user1", - private_key_file = os.getenv("HOME") .. "/.ssh/id_ed25519", -}) +komandan.defaults:set_user("user1") +komandan.defaults:set_private_key_file(os.getenv("HOME") .. "/.ssh/id_ed25519") for _, host in pairs(hosts) do komandan.komando(host, { diff --git a/http-client/src/lib.rs b/http-client/src/lib.rs index d7ead2e..fbe0f41 100644 --- a/http-client/src/lib.rs +++ b/http-client/src/lib.rs @@ -607,3 +607,190 @@ pub fn create_client_from_url(url: &str) -> Result<(HttpClient, String), Box { + assert_eq!(configured_proxy.host, "proxy.example.com"); + assert_eq!(configured_proxy.port, 8080); + assert_eq!( + configured_proxy.auth, + Some(("username".to_string(), "password".to_string())) + ); + assert_eq!(configured_proxy.use_https, false); + } + None => panic!("Proxy should be configured"), + } + } + + #[test] + fn test_http_method_to_string() { + assert_eq!(HttpMethod::GET.to_string(), "GET"); + assert_eq!(HttpMethod::POST.to_string(), "POST"); + assert_eq!(HttpMethod::PUT.to_string(), "PUT"); + assert_eq!(HttpMethod::DELETE.to_string(), "DELETE"); + assert_eq!(HttpMethod::PATCH.to_string(), "PATCH"); + assert_eq!(HttpMethod::HEAD.to_string(), "HEAD"); + assert_eq!(HttpMethod::CONNECT.to_string(), "CONNECT"); + } + + #[test] + fn test_http_client_initialization() { + let client = HttpClient::new("https://api.example.com"); + assert_eq!(client.host, "https://api.example.com"); + assert!(client.auth.is_none()); + assert!(client.headers.is_empty()); + assert!(client.timeout.is_none()); + assert_eq!(client.max_redirects, 5); + assert!(client.proxy.is_none()); + assert!(client.verify_ssl); + assert!(client.enable_ipv6); + } + + #[test] + fn test_http_client_configuration() { + let mut client = HttpClient::new("https://api.example.com"); + + client.set_auth("username", "password"); + assert_eq!( + client.auth, + Some(("username".to_string(), "password".to_string())) + ); + + client.set_header("User-Agent", "Test Client"); + assert_eq!( + client.headers.get("User-Agent"), + Some(&"Test Client".to_string()) + ); + + let timeout = Duration::from_secs(30); + client.set_timeout(timeout); + assert_eq!(client.timeout, Some(timeout)); + + client.set_max_redirects(3); + assert_eq!(client.max_redirects, 3); + + client.set_verify_ssl(false); + assert!(!client.verify_ssl); + + client.set_enable_ipv6(false); + assert!(!client.enable_ipv6); + } + + #[test] + fn test_http_response_status_checks() { + let success_response = HttpResponse { + status_code: 200, + headers: HashMap::new(), + body: Vec::new(), + content_type: None, + }; + assert!(success_response.is_success()); + assert!(!success_response.is_error()); + + let client_error_response = HttpResponse { + status_code: 404, + headers: HashMap::new(), + body: Vec::new(), + content_type: None, + }; + assert!(client_error_response.is_client_error()); + assert!(client_error_response.is_error()); + + let server_error_response = HttpResponse { + status_code: 500, + headers: HashMap::new(), + body: Vec::new(), + content_type: None, + }; + assert!(server_error_response.is_server_error()); + assert!(server_error_response.is_error()); + } + + #[test] + fn test_parse_url() -> Result<(), Box> { + let test_cases = vec![ + ("https://example.com/path?query=value", ParsedUrl { + scheme: "https".to_string(), + host: "example.com".to_string(), + path: "/path".to_string(), + query: Some("query=value".to_string()), + }), + ("http://example.com", ParsedUrl { + scheme: "http".to_string(), + host: "example.com".to_string(), + path: "/".to_string(), + query: None, + }), + ("https://api.example.com/v1/users/", ParsedUrl { + scheme: "https".to_string(), + host: "api.example.com".to_string(), + path: "/v1/users/".to_string(), + query: None, + }), + ]; + + for (input, expected) in test_cases { + let parsed = parse_url(input)?; + assert_eq!(parsed.scheme, expected.scheme); + assert_eq!(parsed.host, expected.host); + assert_eq!(parsed.path, expected.path); + assert_eq!(parsed.query, expected.query); + } + + Ok(()) + } + + #[test] + fn test_create_client_from_url() -> Result<(), Box> { + let (client, path) = + create_client_from_url("https://api.example.com/v1/users?active=true")?; + + assert_eq!(client.host, "https://api.example.com"); + assert_eq!(path, "/v1/users?active=true"); + + let (client2, path2) = create_client_from_url("http://example.com")?; + assert_eq!(client2.host, "http://example.com"); + assert_eq!(path2, "/"); + + Ok(()) + } + + #[test] + fn test_invalid_url() { + let result = parse_url("invalid-url"); + assert!(result.is_err()); + + let result = create_client_from_url("invalid-url"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_response() { + let client = HttpClient::new("https://example.com"); + let raw_response = "HTTP/1.1 200 OK\r\n\ + Content-Type: application/json\r\n\ + Content-Length: 2\r\n\ + \r\n\ + {}"; + + let response = client.parse_response(raw_response).unwrap(); + assert_eq!(response.status_code, 200); + assert_eq!(response.content_type, Some("application/json".to_string())); + assert_eq!(response.body, "{}".as_bytes()); + } +} diff --git a/src/defaults.rs b/src/defaults.rs index ed38627..1fbdda6 100644 --- a/src/defaults.rs +++ b/src/defaults.rs @@ -1,21 +1,425 @@ -use mlua::{Lua, Table}; - -pub fn defaults(lua: &Lua) -> mlua::Result { - let defaults = lua.create_table()?; - - defaults.set("port", 22)?; - defaults.set("ignore_exit_code", false)?; - defaults.set("elevate", false)?; - defaults.set("elevation_method", "sudo")?; - defaults.set( - "known_hosts_file", - format!("{}/.ssh/known_hosts", env!("HOME")), - )?; - defaults.set("host_key_check", true)?; - - let env = lua.create_table()?; - env.set("DEBIAN_FRONTEND", "noninteractive")?; - defaults.set("env", env)?; - - Ok(defaults) +use mlua::UserData; +use std::{ + collections::HashMap, + sync::{Arc, Mutex, OnceLock}, +}; + +static GLOBAL_STATE: OnceLock = OnceLock::new(); + +#[derive(Clone)] +pub struct Defaults { + pub port: Arc>, + pub user: Arc>>, + pub private_key_file: Arc>>, + pub private_key_pass: Arc>>, + pub password: Arc>>, + pub ignore_exit_code: Arc>, + pub elevate: Arc>, + pub elevation_method: Arc>, + pub as_user: Arc>>, + pub known_hosts_file: Arc>, + pub host_key_check: Arc>, + pub env: Arc>>, +} + +impl Defaults { + pub fn new() -> Self { + let env = Arc::new(Mutex::new(HashMap::new())); + match env.lock() { + Ok(mut env) => { + env.insert("DEBIAN_FRONTEND".to_string(), "noninteractive".to_string()); + } + Err(_) => {} + } + + Self { + port: Arc::new(Mutex::new(22)), + user: Arc::new(Mutex::new(None)), + private_key_file: Arc::new(Mutex::new(None)), + private_key_pass: Arc::new(Mutex::new(None)), + password: Arc::new(Mutex::new(None)), + ignore_exit_code: Arc::new(Mutex::new(false)), + elevate: Arc::new(Mutex::new(false)), + elevation_method: Arc::new(Mutex::new("sudo".to_string())), + as_user: Arc::new(Mutex::new(None)), + known_hosts_file: Arc::new(Mutex::new(format!("{}/.ssh/known_hosts", env!("HOME")))), + host_key_check: Arc::new(Mutex::new(true)), + env, + } + } + + pub fn global() -> Self { + GLOBAL_STATE.get_or_init(|| Defaults::new()).clone() + } +} + +impl UserData for Defaults { + fn add_methods>(methods: &mut M) { + methods.add_method("get_port", |_, this, ()| -> mlua::Result { + match this.port.lock() { + Ok(port) => Ok(*port), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }); + + methods.add_method_mut("set_port", |_, this, new_port: u16| -> mlua::Result<()> { + match this.port.lock() { + Ok(mut port) => { + *port = new_port; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }); + + methods.add_method("get_user", |_, this, ()| -> mlua::Result> { + match this.user.lock() { + Ok(user) => Ok(user.clone()), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }); + + methods.add_method_mut( + "set_user", + |_, this, new_user: Option| -> mlua::Result<()> { + match this.user.lock() { + Ok(mut user) => { + *user = new_user; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method( + "get_private_key_file", + |_, this, ()| -> mlua::Result> { + match this.private_key_file.lock() { + Ok(private_key_file) => Ok(private_key_file.clone()), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method_mut( + "set_private_key_file", + |_, this, new_private_key_file: Option| -> mlua::Result<()> { + match this.private_key_file.lock() { + Ok(mut private_key_file) => { + *private_key_file = new_private_key_file; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method( + "get_private_key_pass", + |_, this, ()| -> mlua::Result> { + match this.private_key_pass.lock() { + Ok(private_key_pass) => Ok(private_key_pass.clone()), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method_mut( + "set_private_key_pass", + |_, this, new_private_key_pass: Option| -> mlua::Result<()> { + match this.private_key_pass.lock() { + Ok(mut private_key_pass) => { + *private_key_pass = new_private_key_pass; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method( + "get_password", + |_, this, ()| -> mlua::Result> { + match this.password.lock() { + Ok(password) => Ok(password.clone()), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method_mut( + "set_password", + |_, this, new_password: Option| -> mlua::Result<()> { + match this.password.lock() { + Ok(mut password) => { + *password = new_password; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method( + "get_ignore_exit_code", + |_, this, ()| -> mlua::Result { + match this.ignore_exit_code.lock() { + Ok(ignore_exit_code) => Ok(*ignore_exit_code), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method_mut( + "set_ignore_exit_code", + |_, this, new_ignore_exit_code: bool| -> mlua::Result<()> { + match this.ignore_exit_code.lock() { + Ok(mut ignore_exit_code) => { + *ignore_exit_code = new_ignore_exit_code; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method("get_elevate", |_, this, ()| -> mlua::Result { + match this.elevate.lock() { + Ok(elevate) => Ok(*elevate), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }); + + methods.add_method_mut( + "set_elevate", + |_, this, new_elevate: bool| -> mlua::Result<()> { + match this.elevate.lock() { + Ok(mut elevate) => { + *elevate = new_elevate; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method( + "get_elevation_method", + |_, this, ()| -> mlua::Result { + match this.elevation_method.lock() { + Ok(elevation_method) => Ok(elevation_method.clone()), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method_mut( + "set_elevation_method", + |_, this, new_elevation_method: String| -> mlua::Result<()> { + match this.elevation_method.lock() { + Ok(mut elevation_method) => { + *elevation_method = new_elevation_method; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method( + "get_as_user", + |_, this, ()| -> mlua::Result> { + match this.as_user.lock() { + Ok(as_user) => Ok(as_user.clone()), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method_mut( + "set_as_user", + |_, this, new_as_user: Option| -> mlua::Result<()> { + match this.as_user.lock() { + Ok(mut as_user) => { + *as_user = new_as_user; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method( + "get_known_hosts_file", + |_, this, ()| -> mlua::Result { + match this.known_hosts_file.lock() { + Ok(known_hosts_file) => Ok(known_hosts_file.clone()), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method_mut( + "set_known_hosts_file", + |_, this, new_known_hosts_file: String| -> mlua::Result<()> { + match this.known_hosts_file.lock() { + Ok(mut known_hosts_file) => { + *known_hosts_file = new_known_hosts_file; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method("get_host_key_check", |_, this, ()| -> mlua::Result { + match this.host_key_check.lock() { + Ok(host_key_check) => Ok(*host_key_check), + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }); + + methods.add_method_mut( + "set_host_key_check", + |_, this, new_host_key_check: bool| -> mlua::Result<()> { + match this.host_key_check.lock() { + Ok(mut host_key_check) => { + *host_key_check = new_host_key_check; + Ok(()) + } + Err(_) => { + return Err(mlua::Error::RuntimeError( + "Failed to acquire lock".to_string(), + )) + } + } + }, + ); + + methods.add_method("get_all_env", |lua, this, ()| match this.env.lock() { + Ok(map) => { + let keys: Vec = map.keys().cloned().collect(); + lua.create_table_from(keys.into_iter().enumerate()) + } + Err(_) => Err(mlua::Error::runtime("Failed to acquire lock")), + }); + + methods.add_method_mut("get_env", |_, this, key: String| -> mlua::Result { + match this.env.lock() { + Ok(map) => match map.get(&key) { + Some(value) => Ok(value.clone()), + None => Ok(String::new()), + }, + Err(_) => Err(mlua::Error::runtime("Failed to acquire lock")), + } + }); + + methods.add_method_mut( + "set_env", + |_, this, (key, value): (String, String)| -> mlua::Result<()> { + match this.env.lock() { + Ok(mut map) => { + map.insert(key, value); + Ok(()) + } + Err(_) => Err(mlua::Error::runtime("Failed to acquire lock")), + } + }, + ); + + methods.add_method_mut("remove_env", |_, this, key: String| -> mlua::Result<()> { + match this.env.lock() { + Ok(mut map) => { + map.remove(&key); + Ok(()) + } + Err(_) => Err(mlua::Error::runtime("Failed to acquire lock")), + } + }); + } } diff --git a/src/lib.rs b/src/lib.rs index 7b698b1..0d934ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,14 +7,15 @@ mod validator; use args::Args; use clap::Parser; +use defaults::Defaults; use mlua::{chunk, Error::RuntimeError, Integer, Lua, MultiValue, Table, Value}; use modules::{apt, base_module, cmd, download, lineinfile, script, template, upload}; use rustyline::DefaultEditor; -use ssh::{ElevateMethod, Elevation, SSHAuthMethod, SSHSession}; +use ssh::{Elevation, ElevationMethod, SSHAuthMethod, SSHSession}; use std::{env, fs, path::Path}; use util::{ dprint, filter_hosts, host_display, parse_hosts_json_file, parse_hosts_json_url, - regex_is_match, set_defaults, task_display, + regex_is_match, task_display, }; use validator::{validate_host, validate_task}; @@ -55,12 +56,12 @@ pub fn create_lua() -> mlua::Result { pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { let komandan = lua.create_table()?; - komandan.set("defaults", defaults::defaults(&lua)?)?; + let defaults = Defaults::global(); + komandan.set("defaults", defaults)?; let base_module = base_module(&lua); komandan.set("KomandanModule", base_module)?; - komandan.set("set_defaults", lua.create_function(set_defaults)?)?; komandan.set("komando", lua.create_function(komando)?)?; // Add utils @@ -92,62 +93,81 @@ pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> { Ok(()) } -fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result
{ - let host = lua.create_function(validate_host)?.call::
(&host)?; - let task = lua.create_function(validate_task)?.call::
(&task)?; - let module = task.get::
(1)?; - - let host_display = host_display(&host); - let task_display = task_display(&task); - - let defaults = lua - .globals() - .get::
("komandan")? - .get::
("defaults")?; - +fn get_user(host: &Table, task: &Table) -> mlua::Result { + let defaults = Defaults::global(); + let default_user = match defaults.user.lock() { + Ok(user) => user, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; let user = match host.get::("user") { Ok(user) => user, - Err(_) => match defaults.get::("user") { - Ok(user) => user, - Err(_) => match env::var("USER") { + Err(_) => match *default_user { + Some(ref user) => user.clone(), + None => match env::var("USER") { Ok(user) => user, Err(_) => { return Err(RuntimeError(format!( "No user specified for task '{}'.", - task_display + task_display(task) ))) } }, }, }; + Ok(user) +} + +fn get_auth_config(host: &Table, task: &Table) -> mlua::Result<(String, SSHAuthMethod)> { + let host_display = host_display(host); + let task_display = task_display(task); + + let user = get_user(host, task)?; + + let defaults = Defaults::global(); + + let default_private_key_file = match defaults.private_key_file.lock() { + Ok(private_key_file) => private_key_file, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; + + let default_private_key_pass = match defaults.private_key_pass.lock() { + Ok(private_key_pass) => private_key_pass, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; + + let default_password = match defaults.password.lock() { + Ok(password) => password, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; + let ssh_auth_method = match host.get::("private_key_file") { Ok(private_key_file) => SSHAuthMethod::PublicKey { private_key: private_key_file, passphrase: match host.get::("private_key_pass") { Ok(passphrase) => Some(passphrase), - Err(_) => match defaults.get::("private_key_pass") { - Ok(passphrase) => Some(passphrase), - Err(_) => None, + Err(_) => match *default_private_key_pass { + Some(ref private_key_pass) => Some(private_key_pass.clone()), + None => None, }, }, }, - Err(_) => match defaults.get::("private_key_file") { - Ok(private_key_file) => SSHAuthMethod::PublicKey { - private_key: private_key_file, + Err(_) => match *default_private_key_file { + Some(ref private_key_file) => SSHAuthMethod::PublicKey { + private_key: private_key_file.clone(), passphrase: match host.get::("private_key_pass") { Ok(passphrase) => Some(passphrase), - Err(_) => match defaults.get::("private_key_pass") { - Ok(passphrase) => Some(passphrase), - Err(_) => None, + Err(_) => match *default_private_key_pass { + Some(ref passphrase) => Some(passphrase.clone()), + None => None, }, }, }, - Err(_) => match host.get::("password") { + None => match host.get::("password") { Ok(password) => SSHAuthMethod::Password(password), - Err(_) => match defaults.get::("password") { - Ok(password) => SSHAuthMethod::Password(password), - Err(_) => { + Err(_) => match *default_password { + Some(ref password) => SSHAuthMethod::Password(password.clone()), + None => { return Err(RuntimeError(format!( "No authentication method specified for task '{}' on host '{}'.", task_display, host_display @@ -158,110 +178,181 @@ fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result
{ }, }; - let port = host.get::("port").unwrap_or_else(|_| { - defaults - .get::("port") - .unwrap_or_else(|_| 22.into()) - }) as u16; - - let elevate = match task.get::("elevate") { - Ok(elevate) => match elevate { - Value::Nil => match host.get::("elevate") { - Ok(elevate) => match elevate { - Value::Nil => match defaults.get::("elevate") { - Ok(elevate) => match elevate { - Value::Nil => false, - Value::Boolean(true) => true, - _ => false, - }, - Err(_) => false, - }, - Value::Boolean(true) => true, - _ => false, - }, - Err(_) => false, - }, - Value::Boolean(true) => true, - _ => false, - }, - Err(_) => false, + Ok((user, ssh_auth_method)) +} + +fn get_elevation_config(host: &Table, task: &Table) -> mlua::Result { + let defaults = Defaults::global(); + + let default_elevate = match defaults.elevate.lock() { + Ok(elevate) => elevate, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), }; - let as_user: Option = match task.get::("as_user") { - Ok(as_user) => Some(as_user), - Err(_) => match host.get::("as_user") { - Ok(as_user) => Some(as_user), - Err(_) => match defaults.get::("as_user") { - Ok(as_user) => Some(as_user), - Err(_) => None, - }, - }, + let elevate = task + .get::("elevate") + .unwrap_or(host.get::("elevate").unwrap_or(*default_elevate)); + + if !elevate { + return Ok(Elevation { + method: ElevationMethod::None, + as_user: None, + }); + } + + let default_elevation_method = match defaults.elevation_method.lock() { + Ok(elevation_method) => elevation_method, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), }; - let elevation_method_str = match elevate { - true => match user.as_str() { - "root" => "su".to_string(), - _ => match task.get::("elevation_method") { - Ok(method) => method, - Err(_) => match host.get::("elevation_method") { - Ok(method) => method, - Err(_) => match defaults.get::("elevation_method") { - Ok(method) => method, - Err(_) => "none".to_string(), - }, - }, - }, - }, - false => "none".to_string(), + let elevation_method_str = task.get::("elevation_method").unwrap_or( + host.get::("elevation_method") + .unwrap_or(default_elevation_method.clone()), + ); + + let elevation_method = match elevation_method_str.as_str() { + "none" => Ok(ElevationMethod::None), + "sudo" => Ok(ElevationMethod::Sudo), + "su" => Ok(ElevationMethod::Su), + _ => Err(RuntimeError(format!( + "Unsupported elevation method: {}", + elevation_method_str + ))), }; - let elevation_method = match elevation_method_str.to_lowercase().as_str() { - "none" => ElevateMethod::None, - "su" => ElevateMethod::Su, - "sudo" => ElevateMethod::Sudo, - _ => { - return Err(RuntimeError(format!( - "Invalid elevation_method '{}' for task '{}' on host '{}'.", - elevation_method_str, task_display, host_display - ))) - } + let default_as_user = match defaults.as_user.lock() { + Ok(as_user) => as_user, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), }; - let elevation = Elevation { - method: elevation_method, + let as_user = task.get::>("as_user").unwrap_or( + host.get::>("as_user") + .unwrap_or(default_as_user.clone()), + ); + + Ok(Elevation { + method: elevation_method?, as_user, - }; + }) +} - let known_hosts_file = match host.get::("known_hosts_file") { - Ok(known_hosts_file) => Some(known_hosts_file), - Err(_) => match defaults.get::("known_hosts_file") { - Ok(known_hosts_file) => Some(known_hosts_file), - Err(_) => None, - }, +fn setup_ssh_session(host: &Table) -> mlua::Result { + let defaults = Defaults::global(); + let mut ssh = SSHSession::new()?; + + let default_host_key_check = match defaults.host_key_check.lock() { + Ok(host_key_check) => host_key_check, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), }; let host_key_check = match host.get::("host_key_check") { Ok(host_key_check) => match host_key_check { - Value::Nil => match defaults.get::("host_key_check") { - Ok(host_key_check) => match host_key_check { - Value::Nil => true, - Value::Boolean(false) => false, - _ => true, - }, - Err(_) => true, - }, + Value::Nil => default_host_key_check.clone(), Value::Boolean(false) => false, _ => true, }, Err(_) => true, }; - let mut ssh = SSHSession::new()?; + let default_known_hosts_file = match defaults.known_hosts_file.lock() { + Ok(known_hosts_file) => known_hosts_file, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; if host_key_check { - ssh.known_hosts_file = known_hosts_file; + ssh.known_hosts_file = match host.get::("known_hosts_file") { + Ok(known_hosts_file) => Some(known_hosts_file), + Err(_) => Some(default_known_hosts_file.clone()), + }; } + Ok(ssh) +} + +fn setup_environment(ssh: &mut SSHSession, host: &Table, task: &Table) -> mlua::Result<()> { + let defaults = Defaults::global(); + + let default_env = match defaults.env.lock() { + Ok(env) => env, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; + + let env_host = host.get::>("env")?; + let env_task = task.get::>("env")?; + + for (key, value) in default_env.clone() { + ssh.set_env(&key, &value); + } + + if !env_host.is_none() { + for pair in env_host.unwrap().pairs() { + let (key, value): (String, String) = pair?; + ssh.set_env(&key, &value); + } + } + + if !env_task.is_none() { + for pair in env_task.unwrap().pairs() { + let (key, value): (String, String) = pair?; + ssh.set_env(&key, &value); + } + } + + Ok(()) +} + +fn execute_task( + lua: &Lua, + module: &Table, + ssh: SSHSession, + task_display: &str, + host_display: &str, +) -> mlua::Result
{ + lua.load(chunk! { + print("Running task '" .. $task_display .. "' on host '" .. $host_display .."' ...") + $module.ssh = $ssh + $module:run() + + local results = $module.ssh:get_session_results() + komandan.dprint(results.stdout) + if results.exit_code ~= 0 then + print("Task '" .. $task_display .. "' on host '" .. $host_display .."' failed with exit code " .. results.exit_code .. ": " .. results.stderr) + else + print("Task '" .. $task_display .. "' on host '" .. $host_display .."' succeeded.") + end + + if $module.cleanup ~= nil then + $module:cleanup() + end + + return results + }) + .eval::
() +} + +fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result
{ + let host = lua.create_function(validate_host)?.call::
(&host)?; + let task = lua.create_function(validate_task)?.call::
(&task)?; + let module = task.get::
(1)?; + + let host_display = host_display(&host); + let task_display = task_display(&task); + + let defaults = Defaults::global(); + + let (user, ssh_auth_method) = get_auth_config(&host, &task)?; + let elevation = get_elevation_config(&host, &task)?; + + let default_port = match defaults.port.lock() { + Ok(port) => port, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; + + let port = host + .get::("port") + .unwrap_or(default_port.clone() as i64) as u16; + + let mut ssh = setup_ssh_session(&host)?; ssh.elevation = elevation; ssh.connect( @@ -271,50 +362,18 @@ fn komando(lua: &Lua, (host, task): (Value, Value)) -> mlua::Result
{ ssh_auth_method, )?; - let env_defaults = defaults.get::
("env").unwrap_or(lua.create_table()?); - let env_host = host.get::
("env").unwrap_or(lua.create_table()?); - let env_task = task.get::
("env").unwrap_or(lua.create_table()?); + setup_environment(&mut ssh, &host, &task)?; - for pair in env_defaults.pairs() { - let (key, value): (String, String) = pair?; - ssh.set_env(&key, &value); - } - - for pair in env_host.pairs() { - let (key, value): (String, String) = pair?; - ssh.set_env(&key, &value); - } + let results = execute_task(lua, &module, ssh, &task_display, &host_display)?; - for pair in env_task.pairs() { - let (key, value): (String, String) = pair?; - ssh.set_env(&key, &value); - } - - let results = lua - .load(chunk! { - print("Running task '" .. $task_display .. "' on host '" .. $host_display .."' ...") - $module.ssh = $ssh - $module:run() - - local results = $module.ssh:get_session_results() - komandan.dprint(results.stdout) - if results.exit_code ~= 0 then - print("Task '" .. $task_display .. "' on host '" .. $host_display .."' failed with exit code " .. results.exit_code .. ": " .. results.stderr) - else - print("Task '" .. $task_display .. "' on host '" .. $host_display .."' succeeded.") - end - - if $module.cleanup ~= nil then - $module:cleanup() - end - - return results - }) - .eval::
()?; + let default_ignore_exit_code = match defaults.ignore_exit_code.lock() { + Ok(ignore_exit_code) => ignore_exit_code, + Err(_) => return Err(RuntimeError("Failed to acquire lock".to_string())), + }; let ignore_exit_code = task .get::("ignore_exit_code") - .unwrap_or_else(|_| defaults.get::("ignore_exit_code").unwrap()); + .unwrap_or(default_ignore_exit_code.clone()); if results.get::("exit_code").unwrap() != 0 && !ignore_exit_code { return Err(RuntimeError("Failed to run task.".to_string())); @@ -419,7 +478,6 @@ mod tests { let komandan_table = lua.globals().get::
("komandan").unwrap(); assert!(komandan_table.contains_key("defaults").unwrap()); assert!(komandan_table.contains_key("KomandanModule").unwrap()); - assert!(komandan_table.contains_key("set_defaults").unwrap()); assert!(komandan_table.contains_key("komando").unwrap()); assert!(komandan_table.contains_key("regex_is_match").unwrap()); assert!(komandan_table.contains_key("filter_hosts").unwrap()); @@ -448,4 +506,149 @@ mod tests { let result = run_main_file(&lua, &main_file); assert!(result.is_ok()); } + + #[test] + fn test_get_auth_config() { + let lua = create_lua().unwrap(); + let host = lua.create_table().unwrap(); + + // Test with user in host + host.set("address", "localhost").unwrap(); + host.set("user", "testuser").unwrap(); + host.set("private_key_file", "/path/to/key").unwrap(); + + let module_params = lua.create_table().unwrap(); + module_params.set("cmd", "echo test").unwrap(); + let module = cmd(&lua, module_params).unwrap(); + let task = lua.create_table().unwrap(); + task.set(1, module).unwrap(); + + let (user, auth) = get_auth_config(&host, &task).unwrap(); + assert_eq!(user, "testuser"); + match auth { + SSHAuthMethod::PublicKey { + private_key, + passphrase, + } => { + assert_eq!(private_key, "/path/to/key"); + assert!(passphrase.is_none()); + } + _ => panic!("Expected PublicKey authentication"), + } + + // Test with password auth + host.set("private_key_file", Value::Nil).unwrap(); + host.set("password", "testpass").unwrap(); + let (_, auth) = get_auth_config(&host, &task).unwrap(); + match auth { + SSHAuthMethod::Password(pass) => assert_eq!(pass, "testpass"), + _ => panic!("Expected Password authentication"), + } + + // Test with no authentication method + host.set("password", Value::Nil).unwrap(); + let result = get_auth_config(&host, &task); + assert!(result.is_err()); + } + + #[test] + fn test_get_elevation_config() { + let lua = Lua::new(); + let host = lua.create_table().unwrap(); + let task = lua.create_table().unwrap(); + + // Test with no elevation + let elevation = get_elevation_config(&host, &task).unwrap(); + assert!(matches!( + elevation, + Elevation { + method: ElevationMethod::None, + as_user: None + } + )); + + // Test with elevation from task + task.set("elevate", true).unwrap(); + let elevation = get_elevation_config(&host, &task).unwrap(); + assert!(matches!( + elevation, + Elevation { + method: ElevationMethod::Sudo, + as_user: None + } + )); + + // Test with custom elevation method + task.set("elevation_method", "su").unwrap(); + let elevation = get_elevation_config(&host, &task).unwrap(); + assert!(matches!( + elevation, + Elevation { + method: ElevationMethod::Su, + as_user: None + } + )); + + // Test invalid elevation method + task.set("elevation_method", "invalid").unwrap(); + assert!(get_elevation_config(&host, &task).is_err()); + } + + #[test] + fn test_setup_ssh_session() { + let lua = create_lua().unwrap(); + let host = lua.create_table().unwrap(); + host.set("address", "localhost").unwrap(); + + // Test with default settings + let ssh = setup_ssh_session(&host).unwrap(); + assert!(ssh.known_hosts_file.is_some()); + + // Test with host key check disabled + host.set("host_key_check", false).unwrap(); + let ssh = setup_ssh_session(&host).unwrap(); + assert!(ssh.known_hosts_file.is_none()); + + // Test with custom known_hosts file + host.set("known_hosts_file", "/path/to/known_hosts") + .unwrap(); + host.set("host_key_check", true).unwrap(); + let ssh = setup_ssh_session(&host).unwrap(); + assert_eq!(ssh.known_hosts_file.unwrap(), "/path/to/known_hosts"); + + // Test with known_hosts from defaults + host.set("known_hosts_file", Value::Nil).unwrap(); + lua.load(chunk! { + komandan.defaults:set_known_hosts_file("/default/known_hosts") + }) + .exec() + .unwrap(); + let ssh = setup_ssh_session(&host).unwrap(); + assert_eq!(ssh.known_hosts_file.unwrap(), "/default/known_hosts"); + } + + #[test] + fn test_setup_environment() { + let lua = create_lua().unwrap(); + let mut ssh = SSHSession::new().unwrap(); + let defaults = lua.create_table().unwrap(); + let host = lua.create_table().unwrap(); + let task = lua.create_table().unwrap(); + + // Test with environment variables at all levels + let env_defaults = lua.create_table().unwrap(); + env_defaults.set("DEFAULT_VAR", "default_value").unwrap(); + defaults.set("env", env_defaults).unwrap(); + + let env_host = lua.create_table().unwrap(); + env_host.set("HOST_VAR", "host_value").unwrap(); + env_host.set("DEFAULT_VAR", "overridden_value").unwrap(); // Override default + host.set("env", env_host).unwrap(); + + let env_task = lua.create_table().unwrap(); + env_task.set("TASK_VAR", "task_value").unwrap(); + task.set("env", env_task).unwrap(); + + setup_environment(&mut ssh, &host, &task).unwrap(); + } } diff --git a/src/ssh.rs b/src/ssh.rs index b012ec2..5e6e3b5 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -20,12 +20,12 @@ pub enum SSHAuthMethod { } pub struct Elevation { - pub method: ElevateMethod, + pub method: ElevationMethod, pub as_user: Option, } #[derive(Debug, PartialEq)] -pub enum ElevateMethod { +pub enum ElevationMethod { None, Su, Sudo, @@ -48,7 +48,7 @@ impl SSHSession { known_hosts_file: None, env: HashMap::new(), elevation: Elevation { - method: ElevateMethod::None, + method: ElevationMethod::None, as_user: None, }, stdout: Some(String::new()), @@ -153,11 +153,11 @@ impl SSHSession { pub fn prepare_command(&mut self, command: &str) -> Result { let command = match self.elevation.method { - ElevateMethod::Su => match &self.elevation.as_user { + ElevationMethod::Su => match &self.elevation.as_user { Some(user) => format!("su {} -c '{}'", user, command), None => format!("su -c '{}'", command), }, - ElevateMethod::Sudo => match &self.elevation.as_user { + ElevationMethod::Sudo => match &self.elevation.as_user { Some(user) => format!("sudo -u {} {}", user, command), None => format!("sudo {}", command), }, @@ -374,7 +374,7 @@ mod tests { let session = SSHSession::new(); assert!(session.is_ok()); let session = session.unwrap(); - assert_eq!(session.elevation.method, ElevateMethod::None); + assert_eq!(session.elevation.method, ElevationMethod::None); assert!(session.env.is_empty()); } @@ -394,25 +394,25 @@ mod tests { assert_eq!(cmd, "ls -la"); // Test with sudo elevation - session.elevation.method = ElevateMethod::Sudo; + session.elevation.method = ElevationMethod::Sudo; session.elevation.as_user = None; let cmd = session.prepare_command("ls -la").unwrap(); assert_eq!(cmd, "sudo ls -la"); // Test with sudo elevation and user - session.elevation.method = ElevateMethod::Sudo; + session.elevation.method = ElevationMethod::Sudo; session.elevation.as_user = Some("admin".to_string()); let cmd = session.prepare_command("ls -la").unwrap(); assert_eq!(cmd, "sudo -u admin ls -la"); // Test with su elevation - session.elevation.method = ElevateMethod::Su; + session.elevation.method = ElevationMethod::Su; session.elevation.as_user = None; let cmd = session.prepare_command("ls -la").unwrap(); assert_eq!(cmd, "su -c 'ls -la'"); // Test with su elevation and user - session.elevation.method = ElevateMethod::Su; + session.elevation.method = ElevationMethod::Su; session.elevation.as_user = Some("admin".to_string()); let cmd = session.prepare_command("ls -la").unwrap(); assert_eq!(cmd, "su admin -c 'ls -la'"); diff --git a/src/util.rs b/src/util.rs index 99a3d91..389b461 100644 --- a/src/util.rs +++ b/src/util.rs @@ -5,26 +5,6 @@ use http_client::create_client_from_url; use mlua::{chunk, Error::RuntimeError, Lua, LuaSerdeExt, Table, Value}; use std::{fs::File, io::Read}; -pub fn set_defaults(lua: &Lua, data: Value) -> mlua::Result<()> { - if !data.is_table() { - return Err(RuntimeError( - "Parameter for set_defaults must be a table.".to_string(), - )); - } - - let defaults = lua - .globals() - .get::
("komandan")? - .get::
("defaults")?; - - for pair in data.as_table().unwrap().pairs() { - let (key, value): (String, Value) = pair?; - defaults.set(key, value.clone())?; - } - - Ok(()) -} - pub fn dprint(lua: &Lua, value: Value) -> mlua::Result<()> { let args = Args::parse(); if args.verbose { @@ -266,7 +246,6 @@ pub fn task_display(task: &Table) -> String { // Tests #[cfg(test)] mod tests { - use mlua::Integer; use tempfile::NamedTempFile; use crate::create_lua; @@ -458,46 +437,6 @@ mod tests { assert!(!result); } - #[test] - fn test_set_defaults() { - let lua = create_lua().unwrap(); - - // Test setting a default value - let defaults_data = lua.create_table().unwrap(); - defaults_data.set("user", "testuser").unwrap(); - set_defaults(&lua, Value::Table(defaults_data)).unwrap(); - - let defaults = lua - .globals() - .get::
("komandan") - .unwrap() - .get::
("defaults") - .unwrap(); - assert_eq!(defaults.get::("user").unwrap(), "testuser"); - - // Test setting multiple default values - let defaults_data = lua.create_table().unwrap(); - defaults_data.set("port", 2222).unwrap(); - defaults_data.set("key", "/path/to/key").unwrap(); - set_defaults(&lua, Value::Table(defaults_data)).unwrap(); - - let defaults = lua - .globals() - .get::
("komandan") - .unwrap() - .get::
("defaults") - .unwrap(); - assert_eq!(defaults.get::("port").unwrap(), 2222); - assert_eq!(defaults.get::("key").unwrap(), "/path/to/key"); - - // Test with non-table input - let result = set_defaults( - &lua, - Value::String(lua.create_string("not_a_table").unwrap()), - ); - assert!(result.is_err()); - } - #[test] fn test_parse_hosts_json_valid() { let lua = create_lua().unwrap(); diff --git a/tests/komando.rs b/tests/komando.rs index 3709f69..bfddfe5 100644 --- a/tests/komando.rs +++ b/tests/komando.rs @@ -85,9 +85,7 @@ fn test_komando_use_default_user() { let result = lua .load(chunk! { - komandan.set_defaults({ - user = "usertest", - }) + komandan.defaults:set_user("usertest") local hosts = { address = "localhost",