Skip to content

Commit

Permalink
Merge pull request #8 from bits0rcerer/main
Browse files Browse the repository at this point in the history
return error instead of panic
  • Loading branch information
jwhb authored Aug 12, 2023
2 parents 3527327 + fd01cfe commit e8ad6fa
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 35 deletions.
37 changes: 29 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ exclude = [
]

[dependencies]
serde = {version = "1.0.137", features = ["derive"]}
serde_json = {version = "1.0.81"}
serde = { version = "1.0.137", features = ["derive"] }
serde_json = { version = "1.0.81" }
serde_path_to_error = "0.1"
strum = "0.24"
strum_macros = "0.24"
thiserror = "1.0"

[build-dependencies]
110 changes: 88 additions & 22 deletions src/helper.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,48 @@
use std::string::FromUtf8Error;
use std::{
io::{self, Write},
process::{Command, Stdio},
};

use thiserror::Error;

use crate::schema::Nftables;

const NFT_EXECUTABLE: &str = "nft"; // search in PATH

pub fn get_current_ruleset(program: Option<&str>, args: Option<Vec<&str>>) -> Nftables {
let output = get_current_ruleset_raw(program, args);
let nftables: Nftables = serde_json::from_str(&output).unwrap();
nftables
#[derive(Error, Debug)]
pub enum NftablesError {
#[error("unable to execute {program}: {inner}")]
NftExecution { program: String, inner: io::Error },
#[error("{program}'s output contained invalid utf8: {inner}")]
NftOutputEncoding {
program: String,
inner: FromUtf8Error,
},
#[error("got invalid json: {0}")]
NftInvalidJson(serde_json::Error),
#[error("{program} did not return successfully while {hint}")]
NftFailed {
program: String,
hint: String,
stdout: String,
stderr: String,
},
}

pub fn get_current_ruleset_raw(program: Option<&str>, args: Option<Vec<&str>>) -> String {
let nft_executable: &str = program.unwrap_or(NFT_EXECUTABLE);
let mut nft_cmd = get_command(Some(nft_executable));
pub fn get_current_ruleset(
program: Option<&str>,
args: Option<Vec<&str>>,
) -> Result<Nftables, NftablesError> {
let output = get_current_ruleset_raw(program, args)?;
serde_json::from_str(&output).map_err(NftablesError::NftInvalidJson)
}

pub fn get_current_ruleset_raw(
program: Option<&str>,
args: Option<Vec<&str>>,
) -> Result<String, NftablesError> {
let mut nft_cmd = get_command(program);
let default_args = ["-j", "list", "ruleset"];
let args: Vec<&str> = match args {
Some(mut args) => {
Expand All @@ -24,21 +51,34 @@ pub fn get_current_ruleset_raw(program: Option<&str>, args: Option<Vec<&str>>) -
}
None => default_args.to_vec(),
};
let output = nft_cmd
let process_result = nft_cmd
.args(args)
.output()
.expect("nft command failed to start");
if !output.status.success() {
panic!("nft failed to show the current ruleset");
.map_err(|e| NftablesError::NftExecution {
inner: e,
program: format!("{}", nft_cmd.get_program().to_str().unwrap()),
})?;

let stdout = read_output(&nft_cmd, process_result.stdout)?;

if !process_result.status.success() {
let stderr = read_output(&nft_cmd, process_result.stderr)?;

return Err(NftablesError::NftFailed {
program: format!("{}", nft_cmd.get_program().to_str().unwrap()),
hint: "getting the current ruleset".to_string(),
stdout,
stderr,
});
}
String::from_utf8(output.stdout).expect("failed to decode nft output as utf8")
Ok(stdout)
}

pub fn apply_ruleset(
nftables: &Nftables,
program: Option<&str>,
args: Option<Vec<&str>>,
) -> io::Result<()> {
) -> Result<(), NftablesError> {
let nftables = serde_json::to_string(nftables).expect("failed to serialize Nftables struct");
apply_ruleset_raw(nftables, program, args)
}
Expand All @@ -47,9 +87,8 @@ pub fn apply_ruleset_raw(
payload: String,
program: Option<&str>,
args: Option<Vec<&str>>,
) -> io::Result<()> {
let nft_executable: &str = program.unwrap_or(NFT_EXECUTABLE);
let mut nft_cmd = get_command(Some(nft_executable));
) -> Result<(), NftablesError> {
let mut nft_cmd = get_command(program);
let default_args = ["-j", "-f", "-"];
let args: Vec<&str> = match args {
Some(mut args) => {
Expand All @@ -62,23 +101,50 @@ pub fn apply_ruleset_raw(
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()?;
.spawn()
.map_err(|e| NftablesError::NftExecution {
program: format!("{}", nft_cmd.get_program().to_str().unwrap()),
inner: e,
})?;

let mut stdin = process.stdin.take().unwrap();
stdin.write_all(payload.as_bytes())?;
stdin
.write_all(payload.as_bytes())
.map_err(|e| NftablesError::NftExecution {
program: format!("{}", nft_cmd.get_program().to_str().unwrap()),
inner: e,
})?;
drop(stdin);

let result = process.wait_with_output();
match result {
Ok(output) => {
assert!(output.status.success());
Ok(())
Ok(output) if output.status.success() => Ok(()),
Ok(process_result) => {
let stdout = read_output(&nft_cmd, process_result.stdout)?;
let stderr = read_output(&nft_cmd, process_result.stderr)?;

Err(NftablesError::NftFailed {
program: format!("{}", nft_cmd.get_program().to_str().unwrap()),
hint: "applying ruleset".to_string(),
stdout,
stderr,
})
}
Err(err) => Err(err),
Err(e) => Err(NftablesError::NftExecution {
program: format!("{}", nft_cmd.get_program().to_str().unwrap()),
inner: e,
}),
}
}

fn get_command(program: Option<&str>) -> Command {
let nft_executable: &str = program.unwrap_or(NFT_EXECUTABLE);
Command::new(nft_executable)
}

fn read_output(cmd: &Command, bytes: Vec<u8>) -> Result<String, NftablesError> {
String::from_utf8(bytes).map_err(|e| NftablesError::NftOutputEncoding {
inner: e,
program: format!("{}", cmd.get_program().to_str().unwrap()),
})
}
9 changes: 6 additions & 3 deletions tests/helper_tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use nftables::{batch::Batch, helper, schema, types, expr};
use nftables::{batch::Batch, expr, helper, schema, types};

#[test]
#[ignore]
/// Reads current ruleset from nftables and reads it to `Nftables` Rust struct.
fn test_list_ruleset() {
helper::get_current_ruleset(None, None);
helper::get_current_ruleset(None, None).unwrap();
}

#[test]
Expand Down Expand Up @@ -42,7 +42,10 @@ fn example_ruleset() -> schema::Nftables {
family: types::NfFamily::IP,
table: table_name,
name: set_name,
elem: vec![expr::Expression::String("127.0.0.1".to_string()), expr::Expression::String("127.0.0.2".to_string())],
elem: vec![
expr::Expression::String("127.0.0.1".to_string()),
expr::Expression::String("127.0.0.2".to_string()),
],
}));
batch.delete(schema::NfListObject::Table(schema::Table::new(
types::NfFamily::IP,
Expand Down

0 comments on commit e8ad6fa

Please sign in to comment.