Skip to content

Commit

Permalink
Refactor lua creation and some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnavi committed Dec 25, 2024
1 parent 8471392 commit 7c98dff
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 66 deletions.
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ use util::{
};
use validator::{validate_host, validate_task};

pub fn setup_lua_env(lua: &Lua) -> mlua::Result<()> {
pub fn create_lua() -> mlua::Result<Lua> {
let lua = Lua::new();
let args = Args::parse();

let project_dir = match args.main_file.clone() {
Expand Down Expand Up @@ -48,7 +49,7 @@ pub fn setup_lua_env(lua: &Lua) -> mlua::Result<()> {

setup_komandan_table(&lua)?;

Ok(())
Ok(lua)
}

pub fn setup_komandan_table(lua: &Lua) -> mlua::Result<()> {
Expand Down
7 changes: 2 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ mod args;

use args::Args;
use clap::Parser;
use komandan::{print_version, repl, run_main_file, setup_lua_env};
use mlua::Lua;
use komandan::{create_lua, print_version, repl, run_main_file};

fn main() -> anyhow::Result<()> {
let args = Args::parse();
Expand All @@ -13,9 +12,7 @@ fn main() -> anyhow::Result<()> {
return Ok(());
}

let lua = Lua::new();

setup_lua_env(&lua)?;
let lua = create_lua()?;

if let Some(chunk) = args.chunk.clone() {
lua.load(&chunk).eval::<()>()?;
Expand Down
64 changes: 25 additions & 39 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,27 +269,11 @@ mod tests {
use mlua::Integer;
use tempfile::NamedTempFile;

use crate::setup_lua_env;
use crate::create_lua;

use super::*;
use std::{env, fs::write, io::Write};

fn setup_lua() -> Lua {
let lua = Lua::new();

// Initialize komandan table and defaults
lua.globals()
.set("komandan", lua.create_table().unwrap())
.unwrap();

let komandan = lua.globals().get::<Table>("komandan").unwrap();
komandan
.set("defaults", lua.create_table().unwrap())
.unwrap();

lua
}

#[test]
fn test_dprint_verbose() {
// Simulate verbose flag being set
Expand All @@ -302,7 +286,7 @@ mod tests {
};
env::set_var("MOCK_ARGS", format!("{:?}", args));

let lua = setup_lua();
let lua = create_lua().unwrap();
let value = Value::String(lua.create_string("Test verbose print").unwrap());
assert!(dprint(&lua, value).is_ok());
}
Expand All @@ -319,14 +303,14 @@ mod tests {
};
env::set_var("MOCK_ARGS", format!("{:?}", args));

let lua = setup_lua();
let lua = create_lua().unwrap();
let value = Value::String(lua.create_string("Test non-verbose print").unwrap());
assert!(dprint(&lua, value).is_ok());
}

#[test]
fn test_filter_hosts_invalid_hosts_type() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let hosts = Value::Nil;
let pattern = Value::String(lua.create_string("host1").unwrap());
let result = filter_hosts(&lua, (hosts, pattern));
Expand All @@ -348,7 +332,7 @@ mod tests {

#[test]
fn test_filter_hosts_invalid_pattern() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let hosts = lua.create_table().unwrap();
hosts.set("host1", lua.create_table().unwrap()).unwrap();
let pattern = Value::Nil;
Expand All @@ -372,7 +356,7 @@ mod tests {

#[test]
fn test_filter_hosts_single_string_pattern() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let hosts = lua.create_table().unwrap();
let host_data = lua.create_table().unwrap();
host_data
Expand All @@ -389,7 +373,7 @@ mod tests {

#[test]
fn test_filter_hosts_table_pattern() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let hosts = lua.create_table().unwrap();
let host_data = lua.create_table().unwrap();
host_data
Expand All @@ -406,7 +390,7 @@ mod tests {

#[test]
fn test_filter_hosts_regex_pattern_host() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let hosts = lua.create_table().unwrap();
let host_data = lua.create_table().unwrap();
host_data
Expand All @@ -423,7 +407,7 @@ mod tests {

#[test]
fn test_filter_hosts_regex_pattern_tag() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let hosts = lua.create_table().unwrap();
let host_data = lua.create_table().unwrap();
host_data
Expand All @@ -440,7 +424,7 @@ mod tests {

#[test]
fn test_filter_hosts_invalid_hosts() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let hosts = Value::String(lua.create_string("not_a_table").unwrap());
let pattern = Value::String(lua.create_string("host1").unwrap());
let result = filter_hosts(&lua, (hosts, pattern));
Expand All @@ -449,7 +433,7 @@ mod tests {

#[test]
fn test_regex_is_match_valid_match() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let text = lua.create_string("hello world").unwrap();
let pattern = lua.create_string("hello").unwrap();
let result = regex_is_match(&lua, (text, pattern)).unwrap();
Expand All @@ -458,7 +442,7 @@ mod tests {

#[test]
fn test_regex_is_match_valid_no_match() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let text = lua.create_string("hello world").unwrap();
let pattern = lua.create_string("goodbye").unwrap();
let result = regex_is_match(&lua, (text, pattern)).unwrap();
Expand All @@ -467,7 +451,7 @@ mod tests {

#[test]
fn test_regex_is_match_invalid_regex() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let text = lua.create_string("hello world").unwrap();
let pattern = lua.create_string("[").unwrap();
let result = regex_is_match(&lua, (text, pattern)).unwrap();
Expand All @@ -476,8 +460,7 @@ mod tests {

#[test]
fn test_set_defaults() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

// Test setting a default value
let defaults_data = lua.create_table().unwrap();
Expand Down Expand Up @@ -517,7 +500,7 @@ mod tests {

#[test]
fn test_parse_hosts_json_valid() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let temp_file = NamedTempFile::new().unwrap();
let json_content = r#"[
{
Expand All @@ -538,7 +521,7 @@ mod tests {

#[test]
fn test_parse_hosts_json_invalid_path() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let lua_string = Value::String(lua.create_string("/nonexistent/path").unwrap());
let result = parse_hosts_json_file(&lua, lua_string);
assert!(result.is_err());
Expand All @@ -550,7 +533,7 @@ mod tests {

#[test]
fn test_parse_hosts_json_invalid_file() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let temp_file = NamedTempFile::new().unwrap();
temp_file
.as_file()
Expand All @@ -570,7 +553,7 @@ mod tests {

#[test]
fn test_parse_hosts_json_invalid_json() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let temp_file = NamedTempFile::new().unwrap();
write(temp_file.path(), "invalid json content").unwrap();

Expand All @@ -583,7 +566,7 @@ mod tests {

#[test]
fn test_parse_hosts_json_invalid_to_lua_value() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let temp_file = NamedTempFile::new().unwrap();
write(temp_file.path(), "true").unwrap();

Expand All @@ -595,15 +578,18 @@ mod tests {
assert!(result
.unwrap_err()
.to_string()
.contains("JSON does not contain a table"));
.contains("Failed to parse JSON file from"));
}

#[test]
fn test_parse_hosts_json_invalid_input_type() {
let lua = setup_lua();
let lua = create_lua().unwrap();
let result = parse_hosts_json_url(&lua, Value::Nil);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Invalid src path"));
assert!(result
.unwrap_err()
.to_string()
.contains("URL must be a strin"));
}

#[test]
Expand Down
31 changes: 11 additions & 20 deletions tests/komando.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use komandan::setup_lua_env;
use mlua::{chunk, Integer, Lua, Table};
use komandan::create_lua;
use mlua::{chunk, Integer, Table};
use std::{env, io::Write};
use tempfile::NamedTempFile;

#[test]
fn test_komando_invalid_known_hosts_path() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

let result = lua
.load(chunk! {
Expand All @@ -32,8 +31,7 @@ fn test_komando_invalid_known_hosts_path() {

#[test]
fn test_komando_known_hosts_check_not_match() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

let result = lua
.load(chunk! {
Expand All @@ -58,8 +56,7 @@ fn test_komando_known_hosts_check_not_match() {

#[test]
fn test_komando_userauth_invalid_password() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

let result = lua
.load(chunk! {
Expand All @@ -84,8 +81,7 @@ fn test_komando_userauth_invalid_password() {

#[test]
fn test_komando_use_default_user() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

let result = lua
.load(chunk! {
Expand Down Expand Up @@ -113,8 +109,7 @@ fn test_komando_use_default_user() {

#[test]
fn test_komando_use_default_user_from_env() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();
env::set_var("USER", "usertest");

let result = lua
Expand All @@ -139,8 +134,7 @@ fn test_komando_use_default_user_from_env() {

#[test]
fn test_komando_no_user_specified() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();
env::remove_var("USER");

let result = lua
Expand Down Expand Up @@ -169,8 +163,7 @@ fn test_komando_no_user_specified() {

#[test]
fn test_komando_simple_cmd() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

let result_table = lua
.load(chunk! {
Expand Down Expand Up @@ -198,8 +191,7 @@ fn test_komando_simple_cmd() {

#[test]
fn test_komando_simple_script() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

let result_table = lua
.load(chunk! {
Expand Down Expand Up @@ -228,8 +220,7 @@ fn test_komando_simple_script() {

#[test]
fn test_komando_script_from_file() {
let lua = Lua::new();
setup_lua_env(&lua).unwrap();
let lua = create_lua().unwrap();

let mut temp_file = NamedTempFile::new().unwrap();
writeln!(temp_file, "echo hello").unwrap();
Expand Down

0 comments on commit 7c98dff

Please sign in to comment.