Skip to content

Commit

Permalink
Merge pull request #4 from samuelba/feature/refactor-and-unit-testing
Browse files Browse the repository at this point in the history
refactor and add many unit tests
  • Loading branch information
samuelba authored Apr 16, 2023
2 parents 991c583 + 3fa2e35 commit 652e46c
Show file tree
Hide file tree
Showing 8 changed files with 951 additions and 335 deletions.
201 changes: 133 additions & 68 deletions Cargo.lock

Large diffs are not rendered by default.

23 changes: 11 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
[package]
name = "traefik_crowdsec_bouncer"
version = "0.2.2"
version = "0.2.3"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
actix-web = "4"
actix-rt = "2"
actix-web = { version = "4.3" }
actix-rt = { version = "2.8" }
chrono = { version = "0.4", features = ["time"] }
custom_error = "1.9"
env_logger = "0.10"
ip_network_table-deps-treebitmap = "0.5"
log = "0.4"
parse_duration = "2"
custom_error = { version = "1.9" }
env_logger = { version = "0.10" }
ip_network_table-deps-treebitmap = { version = "0.5" }
log = { version = "0.4" }
parse_duration = { version = "2.1" }
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
url = "2.3"
serde_json = { version = "1.0" }
url = { version = "2.3" }

[dev-dependencies]
mockito = "1.0.2"
tokio = { version = "1", features = ["full"] }
mockito = { version = "1.0" }
296 changes: 166 additions & 130 deletions src/bouncer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ use parse_duration::parse;
use crate::config::{Config, CrowdSecMode};
use crate::constants::{APPLICATION_JSON, TEXT_PLAIN};
use crate::crowdsec::get_decision;
use crate::errors::TraefikError;
use crate::types::{CacheAttributes, HealthStatus};

#[cfg(test)]
mod tests;

pub struct TraefikHeaders {
ip: String,
}

fn forbidden_response(ip: Option<String>) -> HttpResponse {
if let Some(ip) = ip {
info!("IP: {} is not allowed", ip);
Expand All @@ -39,166 +47,194 @@ fn set_health_status(health_status: Arc<Mutex<HealthStatus>>, healthy: bool) {
}
}

/// Authenticate an IP address.
/// # Arguments
/// * `config` - The configuration.
/// * `health_status` - The health status.
/// * `ipv4_data` - The IPv4 lookup table.
/// * `request` - The HTTP request.
/// # Returns
/// * `HttpResponse` - The HTTP response. Either `Ok` or `Forbidden`.
#[get("/api/v1/forwardAuth")]
pub async fn authenticate(
config: Data<Config>,
health_status: Data<Arc<Mutex<HealthStatus>>>,
ipv4_data: Data<Arc<Mutex<IpLookupTable<Ipv4Addr, CacheAttributes>>>>,
request: HttpRequest,
) -> HttpResponse {
// Get the IP address from the X-Forwarded-For header.
let req_ip_str = if let Some(header_value) = request.headers().get("X-Forwarded-For") {
if let Ok(header_value) = header_value.to_str() {
header_value.to_string()
fn extract_headers(request: &HttpRequest) -> Result<TraefikHeaders, TraefikError> {
let ip = if let Some(ip) = request.headers().get("X-Forwarded-For") {
if let Ok(ip) = ip.to_str() {
ip.to_string()
} else {
warn!("Could not convert 'X-Forwarded-For' to string. Block request.");
return forbidden_response(None);
return Err(TraefikError::BadHeaders);
}
} else {
warn!("No 'X-Forwarded-For' in the request header. Block request.");
return forbidden_response(None);
return Err(TraefikError::BadHeaders);
};

match config.crowdsec_mode {
CrowdSecMode::Stream => {
return if let Ok(ipv4_table) = ipv4_data.lock() {
let req_ip = Ipv4Addr::from_str(&req_ip_str);
match req_ip {
Ok(ip) => {
if ipv4_table.longest_match(ip).is_some() {
forbidden_response(Some(req_ip_str))
} else {
allowed_response(Some(req_ip_str))
}
}
Err(_) => forbidden_response(Some(req_ip_str)),
Ok(TraefikHeaders { ip })
}

pub async fn authenticate_stream_mode(
headers: TraefikHeaders,
ipv4_data: Data<Arc<Mutex<IpLookupTable<Ipv4Addr, CacheAttributes>>>>,
) -> HttpResponse {
if let Ok(ipv4_table) = ipv4_data.lock() {
let req_ip = Ipv4Addr::from_str(&headers.ip);
match req_ip {
Ok(ip) => {
if ipv4_table.longest_match(ip).is_some() {
forbidden_response(Some(headers.ip))
} else {
allowed_response(Some(headers.ip))
}
} else {
warn!("Could not lock the IPv4 lookup table. Block request.");
forbidden_response(Some(req_ip_str))
}
Err(_) => forbidden_response(Some(headers.ip)),
}
CrowdSecMode::Live => {
let req_ip = Ipv4Addr::from_str(&req_ip_str);
match req_ip {
Ok(ip) => {
// Check if IP is in cache.
// If yes, check if it is expired.
// If not, return the cached value.
if let Ok(ipv4_table) = ipv4_data.lock() {
if let Some(cache_attributes) = ipv4_table.exact_match(ip, 32) {
if cache_attributes.expiration_time
> chrono::Utc::now().timestamp_millis()
{
return if cache_attributes.allowed {
allowed_response(Some(req_ip_str))
} else {
forbidden_response(Some(req_ip_str))
};
}
}
}

// IP not in cache or expired.
// Call CrowdSec API.
// Update cache.
return match get_decision(
&config.crowdsec_live_url,
&config.crowdsec_api_key,
&req_ip_str,
)
.await
{
Ok(decision) => {
set_health_status(health_status.get_ref().clone(), true);
match decision {
Some(decision) => {
// If the decisions duration is smaller than the cache TTL, use it instead.
let ttl = if let Ok(duration) = parse(&decision.duration) {
min(duration.as_millis() as i64, config.crowdsec_cache_ttl)
} else {
config.crowdsec_cache_ttl
};
} else {
warn!("Could not lock the IPv4 lookup table. Block request.");
forbidden_response(Some(headers.ip))
}
}

// Update cache.
if let Ok(mut ipv4_table) = ipv4_data.lock() {
ipv4_table.insert(
ip,
32,
CacheAttributes {
allowed: false,
expiration_time: chrono::Utc::now()
.timestamp_millis()
+ ttl,
},
);
}
forbidden_response(Some(req_ip_str))
}
None => {
// Update cache.
if let Ok(mut ipv4_table) = ipv4_data.lock() {
ipv4_table.insert(
ip,
32,
CacheAttributes {
allowed: true,
expiration_time: chrono::Utc::now()
.timestamp_millis()
+ config.crowdsec_cache_ttl,
},
);
}
allowed_response(Some(req_ip_str))
}
}
}
Err(err) => {
info!(
"Could not call API. IP: {} is not allowed. Error {}",
req_ip_str, err
);
set_health_status(health_status.get_ref().clone(), false);
forbidden_response(None)
}
};
pub async fn authenticate_live_mode(
headers: TraefikHeaders,
config: Data<Config>,
health_status: Data<Arc<Mutex<HealthStatus>>>,
ipv4_data: Data<Arc<Mutex<IpLookupTable<Ipv4Addr, CacheAttributes>>>>,
) -> HttpResponse {
let req_ip = Ipv4Addr::from_str(&headers.ip);
match req_ip {
Ok(ip) => {
// Check if IP is in cache.
// If yes, check if it is expired.
// If not, return the cached value.
if let Ok(ipv4_table) = ipv4_data.lock() {
if let Some(cache_attributes) = ipv4_table.exact_match(ip, 32) {
if cache_attributes.expiration_time > chrono::Utc::now().timestamp_millis() {
return if cache_attributes.allowed {
allowed_response(Some(headers.ip))
} else {
forbidden_response(Some(headers.ip))
};
}
}
Err(_) => forbidden_response(Some(req_ip_str)),
}
}
CrowdSecMode::None => {
match get_decision(

// IP not in cache or expired.
// Call CrowdSec API.
// Update cache.
return match get_decision(
&config.crowdsec_live_url,
&config.crowdsec_api_key,
&req_ip_str,
&headers.ip,
)
.await
{
Ok(decision) => {
set_health_status(health_status.get_ref().clone(), true);
match decision {
Some(_) => forbidden_response(Some(req_ip_str)),
None => allowed_response(Some(req_ip_str)),
Some(decision) => {
// If the decisions duration is smaller than the cache TTL, use it instead.
let ttl = if let Ok(duration) = parse(&decision.duration) {
min(duration.as_millis() as i64, config.crowdsec_cache_ttl)
} else {
config.crowdsec_cache_ttl
};

// Update cache.
if let Ok(mut ipv4_table) = ipv4_data.lock() {
ipv4_table.insert(
ip,
32,
CacheAttributes {
allowed: false,
expiration_time: chrono::Utc::now().timestamp_millis()
+ ttl,
},
);
}
forbidden_response(Some(headers.ip))
}
None => {
// Update cache.
if let Ok(mut ipv4_table) = ipv4_data.lock() {
ipv4_table.insert(
ip,
32,
CacheAttributes {
allowed: true,
expiration_time: chrono::Utc::now().timestamp_millis()
+ config.crowdsec_cache_ttl,
},
);
}
allowed_response(Some(headers.ip))
}
}
}
Err(err) => {
info!(
"Could not call API. IP: {} is not allowed. Error {}",
req_ip_str, err
headers.ip, err
);
set_health_status(health_status.get_ref().clone(), false);
forbidden_response(None)
}
};
}
Err(_) => forbidden_response(Some(headers.ip)),
}
}

pub async fn authenticate_none_mode(
headers: TraefikHeaders,
config: Data<Config>,
health_status: Data<Arc<Mutex<HealthStatus>>>,
) -> HttpResponse {
match get_decision(
&config.crowdsec_live_url,
&config.crowdsec_api_key,
&headers.ip,
)
.await
{
Ok(decision) => {
set_health_status(health_status.get_ref().clone(), true);
match decision {
Some(_) => forbidden_response(Some(headers.ip)),
None => allowed_response(Some(headers.ip)),
}
}
Err(err) => {
info!(
"Could not call API. IP: {} is not allowed. Error {}",
headers.ip, err
);
set_health_status(health_status.get_ref().clone(), false);
forbidden_response(None)
}
}
}

/// Authenticate an IP address.
/// # Arguments
/// * `config` - The configuration.
/// * `health_status` - The health status.
/// * `ipv4_data` - The IPv4 lookup table.
/// * `request` - The HTTP request.
/// # Returns
/// * `HttpResponse` - The HTTP response. Either `Ok` or `Forbidden`.
#[get("/api/v1/forwardAuth")]
pub async fn authenticate(
config: Data<Config>,
health_status: Data<Arc<Mutex<HealthStatus>>>,
ipv4_data: Data<Arc<Mutex<IpLookupTable<Ipv4Addr, CacheAttributes>>>>,
request: HttpRequest,
) -> HttpResponse {
let headers = match extract_headers(&request) {
Ok(header) => header,
Err(err) => {
warn!(
"Could not get headers from request. Block request. Error: {}",
err
);
return forbidden_response(None);
}
};

match config.crowdsec_mode {
CrowdSecMode::Stream => authenticate_stream_mode(headers, ipv4_data).await,
CrowdSecMode::Live => {
authenticate_live_mode(headers, config, health_status, ipv4_data).await
}
CrowdSecMode::None => authenticate_none_mode(headers, config, health_status).await,
}
}

Expand Down
Loading

0 comments on commit 652e46c

Please sign in to comment.