diff --git a/Cargo.lock b/Cargo.lock index 51c69e2..bc5fc07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,9 +416,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" @@ -478,7 +478,7 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn 2.0.12", + "syn 2.0.15", ] [[package]] @@ -495,7 +495,7 @@ checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.12", + "syn 2.0.15", ] [[package]] @@ -545,13 +545,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" dependencies = [ "errno-dragonfly", "libc", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -669,7 +669,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.12", + "syn 2.0.15", ] [[package]] @@ -714,9 +714,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", @@ -725,9 +725,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.16" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d" +checksum = "66b91535aa35fea1523ad1b86cb6b53c28e0ae566ba4a460f4457e936cad7c6f" dependencies = [ "bytes", "fnv", @@ -814,9 +814,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.25" +version = "0.14.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899" +checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4" dependencies = [ "bytes", "futures-channel", @@ -851,9 +851,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.54" +version = "0.1.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c17cc76786e99f8d2f055c11159e7f0091c42474dcc3189fbab96072e873e6d" +checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -904,13 +904,13 @@ dependencies = [ [[package]] name = "io-lifetimes" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09270fd4fa1111bc614ed2246c7ef56239a3063d5be0d1ec3b589c505d400aeb" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" dependencies = [ "hermit-abi 0.3.1", "libc", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -927,14 +927,14 @@ checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" [[package]] name = "is-terminal" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256017f749ab3117e93acb91063009e1f1bb56d03965b14c2c8df4eb02c524d8" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", "rustix", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -975,9 +975,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.140" +version = "0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" [[package]] name = "link-cplusplus" @@ -1197,9 +1197,9 @@ checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] name = "openssl" -version = "0.10.49" +version = "0.10.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d2f106ab837a24e03672c59b1239669a0596406ff657c3c0835b6b7f0f35a33" +checksum = "7e30d8bc91859781f0a943411186324d580f2bbeb71b452fe91ae344806af3f1" dependencies = [ "bitflags", "cfg-if", @@ -1218,7 +1218,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.12", + "syn 2.0.15", ] [[package]] @@ -1229,9 +1229,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.84" +version = "0.9.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a20eace9dc2d82904039cb76dcf50fb1a0bba071cfd1629720b5d6f1ddba0fa" +checksum = "0d3d193fb1488ad46ffe3aaabc912cc931d02ee8518fe2959aea8ef52718b0c0" dependencies = [ "cc", "libc", @@ -1311,9 +1311,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.54" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e472a104799c74b514a57226160104aa483546de37e839ec50e3c2e41dd87534" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] @@ -1440,16 +1440,16 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.6" +version = "0.37.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d097081ed288dfe45699b72f5b5d648e5f15d64d900c7080273baa20c16a6849" +checksum = "85597d61f83914ddeba6a47b3b8ffe7365107221c2e557ed94426489fefb5f77" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -1510,29 +1510,29 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.159" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" +checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.159" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" +checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" dependencies = [ "proc-macro2", "quote", - "syn 2.0.12", + "syn 2.0.15", ] [[package]] name = "serde_json" -version = "1.0.95" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" dependencies = [ "itoa", "ryu", @@ -1615,9 +1615,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.12" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79d9531f94112cfc3e4c8f5f02cb2b58f72c97b7efd85f70203cc6d8efda5927" +checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822" dependencies = [ "proc-macro2", "quote", @@ -1726,7 +1726,7 @@ checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.12", + "syn 2.0.15", ] [[package]] @@ -1782,7 +1782,7 @@ dependencies = [ [[package]] name = "traefik_crowdsec_bouncer" -version = "0.2.2" +version = "0.2.3" dependencies = [ "actix-rt", "actix-web", @@ -1796,7 +1796,6 @@ dependencies = [ "reqwest", "serde", "serde_json", - "tokio", "url", ] @@ -1993,11 +1992,11 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.46.0" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdacb41e6a96a052c6cb63a144f24900236121c6f63f4f8219fef5977ecb0c25" +checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" dependencies = [ - "windows-targets", + "windows-targets 0.48.0", ] [[package]] @@ -2006,13 +2005,13 @@ version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", ] [[package]] @@ -2021,7 +2020,16 @@ version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" dependencies = [ - "windows-targets", + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", ] [[package]] @@ -2030,13 +2038,28 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] [[package]] @@ -2045,42 +2068,84 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + [[package]] name = "windows_i686_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + [[package]] name = "windows_i686_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + [[package]] name = "winreg" version = "0.10.1" @@ -2101,9 +2166,9 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "6.0.4+zstd.1.5.4" +version = "6.0.5+zstd.1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7afb4b54b8910cf5447638cb54bf4e8a65cbedd783af98b98c62ffe91f185543" +checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" dependencies = [ "libc", "zstd-sys", @@ -2111,9 +2176,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.7+zstd.1.5.4" +version = "2.0.8+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" +checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" dependencies = [ "cc", "libc", diff --git a/Cargo.toml b/Cargo.toml index 28610d4..3f9dfd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,24 +1,23 @@ [package] name = "traefik_crowdsec_bouncer" -version = "0.2.2" +version = "0.2.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -actix-web = "4" -actix-rt = "2" +actix-web = { version = "4.3" } +actix-rt = { version = "2.8" } chrono = { version = "0.4", features = ["time"] } -custom_error = "1.9" -env_logger = "0.10" -ip_network_table-deps-treebitmap = "0.5" -log = "0.4" -parse_duration = "2" +custom_error = { version = "1.9" } +env_logger = { version = "0.10" } +ip_network_table-deps-treebitmap = { version = "0.5" } +log = { version = "0.4" } +parse_duration = { version = "2.1" } reqwest = { version = "0.11", features = ["json"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -url = "2.3" +serde_json = { version = "1.0" } +url = { version = "2.3" } [dev-dependencies] -mockito = "1.0.2" -tokio = { version = "1", features = ["full"] } +mockito = { version = "1.0" } diff --git a/src/bouncer.rs b/src/bouncer.rs index 20d8db0..1f267f7 100644 --- a/src/bouncer.rs +++ b/src/bouncer.rs @@ -15,8 +15,16 @@ use parse_duration::parse; use crate::config::{Config, CrowdSecMode}; use crate::constants::{APPLICATION_JSON, TEXT_PLAIN}; use crate::crowdsec::get_decision; +use crate::errors::TraefikError; use crate::types::{CacheAttributes, HealthStatus}; +#[cfg(test)] +mod tests; + +pub struct TraefikHeaders { + ip: String, +} + fn forbidden_response(ip: Option) -> HttpResponse { if let Some(ip) = ip { info!("IP: {} is not allowed", ip); @@ -39,166 +47,194 @@ fn set_health_status(health_status: Arc>, healthy: bool) { } } -/// Authenticate an IP address. -/// # Arguments -/// * `config` - The configuration. -/// * `health_status` - The health status. -/// * `ipv4_data` - The IPv4 lookup table. -/// * `request` - The HTTP request. -/// # Returns -/// * `HttpResponse` - The HTTP response. Either `Ok` or `Forbidden`. -#[get("/api/v1/forwardAuth")] -pub async fn authenticate( - config: Data, - health_status: Data>>, - ipv4_data: Data>>>, - request: HttpRequest, -) -> HttpResponse { - // Get the IP address from the X-Forwarded-For header. - let req_ip_str = if let Some(header_value) = request.headers().get("X-Forwarded-For") { - if let Ok(header_value) = header_value.to_str() { - header_value.to_string() +fn extract_headers(request: &HttpRequest) -> Result { + let ip = if let Some(ip) = request.headers().get("X-Forwarded-For") { + if let Ok(ip) = ip.to_str() { + ip.to_string() } else { - warn!("Could not convert 'X-Forwarded-For' to string. Block request."); - return forbidden_response(None); + return Err(TraefikError::BadHeaders); } } else { - warn!("No 'X-Forwarded-For' in the request header. Block request."); - return forbidden_response(None); + return Err(TraefikError::BadHeaders); }; - match config.crowdsec_mode { - CrowdSecMode::Stream => { - return if let Ok(ipv4_table) = ipv4_data.lock() { - let req_ip = Ipv4Addr::from_str(&req_ip_str); - match req_ip { - Ok(ip) => { - if ipv4_table.longest_match(ip).is_some() { - forbidden_response(Some(req_ip_str)) - } else { - allowed_response(Some(req_ip_str)) - } - } - Err(_) => forbidden_response(Some(req_ip_str)), + Ok(TraefikHeaders { ip }) +} + +pub async fn authenticate_stream_mode( + headers: TraefikHeaders, + ipv4_data: Data>>>, +) -> HttpResponse { + if let Ok(ipv4_table) = ipv4_data.lock() { + let req_ip = Ipv4Addr::from_str(&headers.ip); + match req_ip { + Ok(ip) => { + if ipv4_table.longest_match(ip).is_some() { + forbidden_response(Some(headers.ip)) + } else { + allowed_response(Some(headers.ip)) } - } else { - warn!("Could not lock the IPv4 lookup table. Block request."); - forbidden_response(Some(req_ip_str)) } + Err(_) => forbidden_response(Some(headers.ip)), } - CrowdSecMode::Live => { - let req_ip = Ipv4Addr::from_str(&req_ip_str); - match req_ip { - Ok(ip) => { - // Check if IP is in cache. - // If yes, check if it is expired. - // If not, return the cached value. - if let Ok(ipv4_table) = ipv4_data.lock() { - if let Some(cache_attributes) = ipv4_table.exact_match(ip, 32) { - if cache_attributes.expiration_time - > chrono::Utc::now().timestamp_millis() - { - return if cache_attributes.allowed { - allowed_response(Some(req_ip_str)) - } else { - forbidden_response(Some(req_ip_str)) - }; - } - } - } - - // IP not in cache or expired. - // Call CrowdSec API. - // Update cache. - return match get_decision( - &config.crowdsec_live_url, - &config.crowdsec_api_key, - &req_ip_str, - ) - .await - { - Ok(decision) => { - set_health_status(health_status.get_ref().clone(), true); - match decision { - Some(decision) => { - // If the decisions duration is smaller than the cache TTL, use it instead. - let ttl = if let Ok(duration) = parse(&decision.duration) { - min(duration.as_millis() as i64, config.crowdsec_cache_ttl) - } else { - config.crowdsec_cache_ttl - }; + } else { + warn!("Could not lock the IPv4 lookup table. Block request."); + forbidden_response(Some(headers.ip)) + } +} - // Update cache. - if let Ok(mut ipv4_table) = ipv4_data.lock() { - ipv4_table.insert( - ip, - 32, - CacheAttributes { - allowed: false, - expiration_time: chrono::Utc::now() - .timestamp_millis() - + ttl, - }, - ); - } - forbidden_response(Some(req_ip_str)) - } - None => { - // Update cache. - if let Ok(mut ipv4_table) = ipv4_data.lock() { - ipv4_table.insert( - ip, - 32, - CacheAttributes { - allowed: true, - expiration_time: chrono::Utc::now() - .timestamp_millis() - + config.crowdsec_cache_ttl, - }, - ); - } - allowed_response(Some(req_ip_str)) - } - } - } - Err(err) => { - info!( - "Could not call API. IP: {} is not allowed. Error {}", - req_ip_str, err - ); - set_health_status(health_status.get_ref().clone(), false); - forbidden_response(None) - } - }; +pub async fn authenticate_live_mode( + headers: TraefikHeaders, + config: Data, + health_status: Data>>, + ipv4_data: Data>>>, +) -> HttpResponse { + let req_ip = Ipv4Addr::from_str(&headers.ip); + match req_ip { + Ok(ip) => { + // Check if IP is in cache. + // If yes, check if it is expired. + // If not, return the cached value. + if let Ok(ipv4_table) = ipv4_data.lock() { + if let Some(cache_attributes) = ipv4_table.exact_match(ip, 32) { + if cache_attributes.expiration_time > chrono::Utc::now().timestamp_millis() { + return if cache_attributes.allowed { + allowed_response(Some(headers.ip)) + } else { + forbidden_response(Some(headers.ip)) + }; + } } - Err(_) => forbidden_response(Some(req_ip_str)), } - } - CrowdSecMode::None => { - match get_decision( + + // IP not in cache or expired. + // Call CrowdSec API. + // Update cache. + return match get_decision( &config.crowdsec_live_url, &config.crowdsec_api_key, - &req_ip_str, + &headers.ip, ) .await { Ok(decision) => { set_health_status(health_status.get_ref().clone(), true); match decision { - Some(_) => forbidden_response(Some(req_ip_str)), - None => allowed_response(Some(req_ip_str)), + Some(decision) => { + // If the decisions duration is smaller than the cache TTL, use it instead. + let ttl = if let Ok(duration) = parse(&decision.duration) { + min(duration.as_millis() as i64, config.crowdsec_cache_ttl) + } else { + config.crowdsec_cache_ttl + }; + + // Update cache. + if let Ok(mut ipv4_table) = ipv4_data.lock() { + ipv4_table.insert( + ip, + 32, + CacheAttributes { + allowed: false, + expiration_time: chrono::Utc::now().timestamp_millis() + + ttl, + }, + ); + } + forbidden_response(Some(headers.ip)) + } + None => { + // Update cache. + if let Ok(mut ipv4_table) = ipv4_data.lock() { + ipv4_table.insert( + ip, + 32, + CacheAttributes { + allowed: true, + expiration_time: chrono::Utc::now().timestamp_millis() + + config.crowdsec_cache_ttl, + }, + ); + } + allowed_response(Some(headers.ip)) + } } } Err(err) => { info!( "Could not call API. IP: {} is not allowed. Error {}", - req_ip_str, err + headers.ip, err ); set_health_status(health_status.get_ref().clone(), false); forbidden_response(None) } + }; + } + Err(_) => forbidden_response(Some(headers.ip)), + } +} + +pub async fn authenticate_none_mode( + headers: TraefikHeaders, + config: Data, + health_status: Data>>, +) -> HttpResponse { + match get_decision( + &config.crowdsec_live_url, + &config.crowdsec_api_key, + &headers.ip, + ) + .await + { + Ok(decision) => { + set_health_status(health_status.get_ref().clone(), true); + match decision { + Some(_) => forbidden_response(Some(headers.ip)), + None => allowed_response(Some(headers.ip)), } } + Err(err) => { + info!( + "Could not call API. IP: {} is not allowed. Error {}", + headers.ip, err + ); + set_health_status(health_status.get_ref().clone(), false); + forbidden_response(None) + } + } +} + +/// Authenticate an IP address. +/// # Arguments +/// * `config` - The configuration. +/// * `health_status` - The health status. +/// * `ipv4_data` - The IPv4 lookup table. +/// * `request` - The HTTP request. +/// # Returns +/// * `HttpResponse` - The HTTP response. Either `Ok` or `Forbidden`. +#[get("/api/v1/forwardAuth")] +pub async fn authenticate( + config: Data, + health_status: Data>>, + ipv4_data: Data>>>, + request: HttpRequest, +) -> HttpResponse { + let headers = match extract_headers(&request) { + Ok(header) => header, + Err(err) => { + warn!( + "Could not get headers from request. Block request. Error: {}", + err + ); + return forbidden_response(None); + } + }; + + match config.crowdsec_mode { + CrowdSecMode::Stream => authenticate_stream_mode(headers, ipv4_data).await, + CrowdSecMode::Live => { + authenticate_live_mode(headers, config, health_status, ipv4_data).await + } + CrowdSecMode::None => authenticate_none_mode(headers, config, health_status).await, } } diff --git a/src/bouncer/tests.rs b/src/bouncer/tests.rs new file mode 100644 index 0000000..d82cc31 --- /dev/null +++ b/src/bouncer/tests.rs @@ -0,0 +1,371 @@ +use super::*; +use actix_web::{http::header, test}; + +#[test] +async fn test_extract_headers() { + let req = test::TestRequest::default() + .insert_header(header::ContentType::plaintext()) + .insert_header(("X-Forwarded-For", "192.168.0.1")) + .to_http_request(); + + let headers = extract_headers(&req).unwrap(); + assert_eq!("192.168.0.1", headers.ip); +} + +#[test] +async fn test_extract_headers_missing_headers() { + let req = test::TestRequest::default() + .insert_header(header::ContentType::plaintext()) + .to_http_request(); + + assert!(extract_headers(&req).is_err()); +} + +#[test] +async fn test_extract_headers_invalid_headers() { + let req = test::TestRequest::default() + .insert_header(header::ContentType::plaintext()) + .insert_header(( + "X-Forwarded-For", + header::HeaderValue::from_bytes(b"\xFF").unwrap(), + )) + .to_http_request(); + + assert!(extract_headers(&req).is_err()); +} + +#[test] +async fn test_authenticate_stream_mode() { + let mut ipv4_table = IpLookupTable::new(); + ipv4_table.insert( + Ipv4Addr::from_str("172.16.0.0").unwrap(), + 16, + CacheAttributes::new(false, 0), + ); + ipv4_table.insert( + Ipv4Addr::from_str("192.168.0.1").unwrap(), + 32, + CacheAttributes::new(false, 0), + ); + let ipv4_data = Data::new(Arc::new(Mutex::new(ipv4_table))); + + assert_eq!( + 200, + authenticate_stream_mode( + TraefikHeaders { + ip: "1.1.1.1".to_string(), + }, + ipv4_data.clone(), + ) + .await + .status() + ); + assert_eq!( + 403, + authenticate_stream_mode( + TraefikHeaders { + ip: "172.16.5.5".to_string(), + }, + ipv4_data.clone(), + ) + .await + .status() + ); + assert_eq!( + 403, + authenticate_stream_mode( + TraefikHeaders { + ip: "192.168.0.1".to_string(), + }, + ipv4_data.clone(), + ) + .await + .status() + ); +} + +#[test] +async fn test_authenticate_live_mode_from_cache() { + // Set up test data. + let config = Data::new(Config { + crowdsec_live_url: "".to_string(), + crowdsec_stream_url: "".to_string(), + crowdsec_api_key: "".to_string(), + crowdsec_mode: CrowdSecMode::Live, + crowdsec_cache_ttl: 60000, + stream_interval: 0, + port: 0, + }); + let health_status = Data::new(Arc::new(Mutex::new(HealthStatus { + live_status: true, + stream_status: true, + }))); + let ipv4_data = Data::new(Arc::new(Mutex::new(IpLookupTable::< + Ipv4Addr, + CacheAttributes, + >::new()))); + if let Ok(mut ipv4_table) = ipv4_data.lock() { + ipv4_table.insert( + Ipv4Addr::from_str("172.16.0.1").unwrap(), + 32, + CacheAttributes { + allowed: true, + expiration_time: chrono::Utc::now().timestamp_millis() + 60000, + }, + ); + ipv4_table.insert( + Ipv4Addr::from_str("172.16.0.2").unwrap(), + 32, + CacheAttributes { + allowed: false, + expiration_time: chrono::Utc::now().timestamp_millis() + 60000, + }, + ); + } + + // Allowed IP. + let response = authenticate_live_mode( + TraefikHeaders { + ip: "172.16.0.1".to_string(), + }, + config.clone(), + health_status.clone(), + ipv4_data.clone(), + ) + .await; + assert_eq!(200, response.status()); + + // Forbidden IP. + let response = authenticate_live_mode( + TraefikHeaders { + ip: "172.16.0.2".to_string(), + }, + config.clone(), + health_status.clone(), + ipv4_data.clone(), + ) + .await; + assert_eq!(403, response.status()); +} + +#[test] +async fn test_authenticate_live_mode_from_api_allowed() { + // Set up test data. + let health_status = Data::new(Arc::new(Mutex::new(HealthStatus { + live_status: true, + stream_status: true, + }))); + let ipv4_data = Data::new(Arc::new(Mutex::new(IpLookupTable::< + Ipv4Addr, + CacheAttributes, + >::new()))); + // Add an expired (blocked) entry to the cache. + if let Ok(mut ipv4_table) = ipv4_data.lock() { + ipv4_table.insert( + Ipv4Addr::from_str("172.16.0.1").unwrap(), + 32, + CacheAttributes { + allowed: false, + expiration_time: chrono::Utc::now().timestamp_millis() - 10000, + }, + ); + } + + let api_key = "my_api_key"; + let ip = "172.16.0.1"; + + // Simulate an allowed IP address. + let mock_response = "null"; + let mut server = mockito::Server::new(); + let mock_server = server + .mock("GET", "/v1/decisions") + .match_header("X-Api-Key", api_key) + .match_query(format!("ip={}&type=ban", ip).as_str()) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(mock_response) + .create_async() + .await; + + let config = Data::new(Config { + crowdsec_live_url: server.url() + "/v1/decisions", + crowdsec_stream_url: "".to_string(), + crowdsec_api_key: api_key.to_string(), + crowdsec_mode: CrowdSecMode::Live, + crowdsec_cache_ttl: 60000, + stream_interval: 0, + port: 0, + }); + + // Allowed IP. + let response = authenticate_live_mode( + TraefikHeaders { ip: ip.to_string() }, + config.clone(), + health_status.clone(), + ipv4_data.clone(), + ) + .await; + assert_eq!(200, response.status()); + if let Ok(ipv4_table) = ipv4_data.lock() { + let res = ipv4_table.exact_match(Ipv4Addr::from_str("172.16.0.1").unwrap(), 32); + assert!(res.is_some()); + assert!(res.unwrap().allowed); + } + + // Clean up the mock server. + mock_server.assert(); + + // Add an expired (allowed) entry to the cache. + if let Ok(mut ipv4_table) = ipv4_data.lock() { + ipv4_table.insert( + Ipv4Addr::from_str("172.16.0.2").unwrap(), + 32, + CacheAttributes { + allowed: true, + expiration_time: chrono::Utc::now().timestamp_millis() - 10000, + }, + ); + } + + let ip = "172.16.0.2"; + + // Simulate a forbidden IP address. + let mock_response = serde_json::json!([ + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Ip", + "type": "ban", + "value": ip + } + ]); + let mock_server = server + .mock("GET", "/v1/decisions") + .match_header("X-Api-Key", api_key) + .match_query(format!("ip={}&type=ban", ip).as_str()) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(mock_response.to_string()) + .create_async() + .await; + + let config = Data::new(Config { + crowdsec_live_url: server.url() + "/v1/decisions", + crowdsec_stream_url: "".to_string(), + crowdsec_api_key: api_key.to_string(), + crowdsec_mode: CrowdSecMode::Live, + crowdsec_cache_ttl: 60000, + stream_interval: 0, + port: 0, + }); + + // Blocked IP. + let response = authenticate_live_mode( + TraefikHeaders { ip: ip.to_string() }, + config.clone(), + health_status.clone(), + ipv4_data.clone(), + ) + .await; + assert_eq!(403, response.status()); + if let Ok(ipv4_table) = ipv4_data.lock() { + let res = ipv4_table.exact_match(Ipv4Addr::from_str("172.16.0.2").unwrap(), 32); + assert!(res.is_some()); + assert!(!res.unwrap().allowed); + } + + // Clean up the mock server. + mock_server.assert(); +} + +#[test] +async fn test_authenticate_none_mode() { + // Set up test data. + let health_status = Data::new(Arc::new(Mutex::new(HealthStatus { + live_status: true, + stream_status: true, + }))); + let api_key = "my_api_key"; + let ip = "172.16.0.1"; + + // Simulate an allowed IP address. + let mock_response = "null"; + let mut server = mockito::Server::new(); + let mock_server = server + .mock("GET", "/v1/decisions") + .match_header("X-Api-Key", api_key) + .match_query(format!("ip={}&type=ban", ip).as_str()) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(mock_response) + .create_async() + .await; + + let config = Data::new(Config { + crowdsec_live_url: server.url() + "/v1/decisions", + crowdsec_stream_url: "".to_string(), + crowdsec_api_key: api_key.to_string(), + crowdsec_mode: CrowdSecMode::Live, + crowdsec_cache_ttl: 60000, + stream_interval: 0, + port: 0, + }); + + // Allowed IP. + let response = authenticate_none_mode( + TraefikHeaders { ip: ip.to_string() }, + config.clone(), + health_status.clone(), + ) + .await; + assert_eq!(200, response.status()); + + // Clean up the mock server. + mock_server.assert(); + + // Simulate a forbidden IP address. + let mock_response = serde_json::json!([ + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Ip", + "type": "ban", + "value": ip + } + ]); + let mock_server = server + .mock("GET", "/v1/decisions") + .match_header("X-Api-Key", api_key) + .match_query(format!("ip={}&type=ban", ip).as_str()) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(mock_response.to_string()) + .create_async() + .await; + + let config = Data::new(Config { + crowdsec_live_url: server.url() + "/v1/decisions", + crowdsec_stream_url: "".to_string(), + crowdsec_api_key: api_key.to_string(), + crowdsec_mode: CrowdSecMode::Live, + crowdsec_cache_ttl: 60000, + stream_interval: 0, + port: 0, + }); + + // Blocked IP. + let response = authenticate_none_mode( + TraefikHeaders { ip: ip.to_string() }, + config.clone(), + health_status.clone(), + ) + .await; + assert_eq!(403, response.status()); + + // Clean up the mock server. + mock_server.assert(); +} diff --git a/src/crowdsec.rs b/src/crowdsec.rs index 704caf3..25b1f32 100644 --- a/src/crowdsec.rs +++ b/src/crowdsec.rs @@ -144,103 +144,122 @@ async fn get_decisions_stream( }) } -/// The main function. -/// Updates the IP lookup tables with the new and deleted decisions at a regular interval. +/// Call the CrowdSec decisions stream API to get the new and deleted decisions. /// # Arguments /// * `config` - The configuration. /// * `health_status` - The health status. -/// * `ipv4_table` - The IPv4 lookup table. -/// * `ipv6_table` - The IPv6 lookup table. -pub async fn stream( +/// * `ipv4_table` - The IPv4 table. +/// * `ipv6_table` - The IPv6 table. +/// * `startup` - If `true`, the API will return all the decisions, otherwise it will only return the new decisions. +async fn update_decisions( config: Config, health_status: Arc>, ipv4_table: Arc>>, ipv6_table: Arc>>, + startup: &mut bool, ) { - let mut startup: bool = true; - let mut interval = time::interval(Duration::from_secs(config.stream_interval)); - loop { - interval.tick().await; - - // let mut ipv4_table_tmp = ipv4_table.lock().unwrap(); - // warn!("Initial insert"); - // ipv4_table_tmp.insert(Ipv4Addr::new(1, 2, 3, 4), 32, true); - // continue; + match get_decisions_stream( + &config.crowdsec_stream_url, + &config.crowdsec_api_key, + *startup, + ) + .await + { + Ok(stream) => { + set_health_status(health_status.clone(), true); + if stream.new.is_none() && stream.deleted.is_none() { + return; + } + info!("Decisions stream: {:?}", stream); - match get_decisions_stream( - &config.crowdsec_stream_url, - &config.crowdsec_api_key, - startup, - ) - .await - { - Ok(stream) => { - set_health_status(health_status.clone(), true); - if stream.new.is_none() && stream.deleted.is_none() { - continue; - } - info!("Decisions stream: {:?}", stream); - - if let Some(ref new) = stream.new { - for decision in new { - let range = get_ip_and_subnet(&decision.value); - match range { - Some(range) => match range.ipv4 { - Some(ipv4) => { - if let Ok(mut table) = ipv4_table.lock() { + if let Some(ref new) = stream.new { + for decision in new { + let range = get_ip_and_subnet(&decision.value); + match range { + Some(range) => match range.ipv4 { + Some(ipv4) => { + if let Ok(mut table) = ipv4_table.lock() { + table.insert( + ipv4, + range.subnet.unwrap_or(32), + CacheAttributes::new(false, 0), + ); + } + } + None => match range.ipv6 { + Some(ipv6) => { + if let Ok(mut table) = ipv6_table.lock() { table.insert( - ipv4, - range.subnet.unwrap_or(32), + ipv6, + range.subnet.unwrap_or(128), CacheAttributes::new(false, 0), ); } } - None => match range.ipv6 { - Some(ipv6) => { - if let Ok(mut table) = ipv6_table.lock() { - table.insert( - ipv6, - range.subnet.unwrap_or(128), - CacheAttributes::new(false, 0), - ); - } - } - None => warn!("Invalid IP (in new): {:?}", decision.value), - }, + None => warn!("Invalid IP (in new): {:?}", decision.value), }, - None => warn!("Invalid IP (in new): {:?}", decision.value), - } + }, + None => warn!("Invalid IP (in new): {:?}", decision.value), } } - if let Some(ref deleted) = stream.deleted { - for decision in deleted { - let range = get_ip_and_subnet(&decision.value); - match range { - Some(range) => match range.ipv4 { - Some(ipv4) => { - if let Ok(mut table) = ipv4_table.lock() { - table.remove(ipv4, range.subnet.unwrap_or(32)); - } + } + if let Some(ref deleted) = stream.deleted { + for decision in deleted { + let range = get_ip_and_subnet(&decision.value); + match range { + Some(range) => match range.ipv4 { + Some(ipv4) => { + if let Ok(mut table) = ipv4_table.lock() { + table.remove(ipv4, range.subnet.unwrap_or(32)); } - None => match range.ipv6 { - Some(ipv6) => { - if let Ok(mut table) = ipv6_table.lock() { - table.remove(ipv6, range.subnet.unwrap_or(128)); - } + } + None => match range.ipv6 { + Some(ipv6) => { + if let Ok(mut table) = ipv6_table.lock() { + table.remove(ipv6, range.subnet.unwrap_or(128)); } - None => warn!("Invalid IP (in deleted): {:?}", decision.value), - }, + } + None => warn!("Invalid IP (in deleted): {:?}", decision.value), }, - None => warn!("Invalid IP (in deleted): {:?}", decision.value), - } + }, + None => warn!("Invalid IP (in deleted): {:?}", decision.value), } } - startup = false; - } - Err(err) => { - error!("Could not call API. Error: {}", err); - set_health_status(health_status.clone(), false); } + *startup = false; + } + Err(err) => { + error!("Could not call API. Error: {}", err); + set_health_status(health_status.clone(), false); } } } + +/// The main function. +/// Updates the IP lookup tables with the new and deleted decisions at a regular interval. +/// # Arguments +/// * `config` - The configuration. +/// * `health_status` - The health status. +/// * `ipv4_table` - The IPv4 lookup table. +/// * `ipv6_table` - The IPv6 lookup table. +pub async fn stream_loop_thread( + config: Config, + health_status: Arc>, + ipv4_table: Arc>>, + ipv6_table: Arc>>, +) { + let mut startup: bool = true; + let mut interval = time::interval(Duration::from_secs(config.stream_interval)); + loop { + interval.tick().await; + + update_decisions( + config.clone(), + health_status.clone(), + ipv4_table.clone(), + ipv6_table.clone(), + &mut startup, + ) + .await; + } +} diff --git a/src/crowdsec/tests.rs b/src/crowdsec/tests.rs index 5701299..4ca6df0 100644 --- a/src/crowdsec/tests.rs +++ b/src/crowdsec/tests.rs @@ -1,6 +1,76 @@ use super::*; -#[tokio::test] +use crate::config::CrowdSecMode; + +const MOCK_RESPONSE_1: &str = r#"{ + "new": [ + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Ip", + "type": "ban", + "value": "1.2.3.4" + }, + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Range", + "type": "ban", + "value": "1.1.0.0/22" + } + ], + "deleted": [ + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Ip", + "type": "ban", + "value": "3.3.3.3" + }, + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Ip", + "type": "ban", + "value": "4.4.4.4" + } +] +}"#; + +const MOCK_RESPONSE_2: &str = r#"{ + "new": [ + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Range", + "type": "ban", + "value": "2.2.0.0/16" + } + ], + "deleted": [ + { + "duration": "33h6m18.03174611s", + "id": 1, + "origin": "CAPI", + "scenario": "abc", + "scope": "Range", + "type": "ban", + "value": "1.1.0.0/22" + } +] +}"#; + +#[test] async fn test_get_decision_banned_ip() { let api_key = "my_api_key"; let ip = "1.2.3.4"; @@ -40,7 +110,7 @@ async fn test_get_decision_banned_ip() { mock_server.assert(); } -#[tokio::test] +#[test] async fn test_get_decision_unbanned_ip() { let api_key = "my_api_key"; let ip = "1.2.3.4"; @@ -67,7 +137,7 @@ async fn test_get_decision_unbanned_ip() { mock_server.assert(); } -#[tokio::test] +#[test] async fn test_get_decision_bad_response() { let api_key = "my_api_key"; let ip = "1.2.3.4"; @@ -96,54 +166,12 @@ async fn test_get_decision_bad_response() { mock_server.assert(); } -#[tokio::test] +#[test] async fn test_get_decisions_stream_startup() { let api_key = "my_api_key"; let startup = true; // Simulate a stream of decisions. - let mock_response = serde_json::json!({ - "new": [ - { - "duration": "33h6m18.03174611s", - "id": 1, - "origin": "CAPI", - "scenario": "abc", - "scope": "Ip", - "type": "ban", - "value": "1.2.3.4" - }, - { - "duration": "33h6m18.03174611s", - "id": 1, - "origin": "CAPI", - "scenario": "abc", - "scope": "Ip", - "type": "ban", - "value": "1.2.3.5" - } - ], - "deleted": [ - { - "duration": "33h6m18.03174611s", - "id": 1, - "origin": "CAPI", - "scenario": "abc", - "scope": "Ip", - "type": "ban", - "value": "1.1.1.1" - }, - { - "duration": "33h6m18.03174611s", - "id": 1, - "origin": "CAPI", - "scenario": "abc", - "scope": "Ip", - "type": "ban", - "value": "1.1.1.2" - } - ] - }); let mut server = mockito::Server::new(); let mock_server = server .mock("GET", "/v1/decisions/stream") @@ -151,7 +179,7 @@ async fn test_get_decisions_stream_startup() { .match_query(format!("startup=true&scope=Ip%2CRange").as_str()) .with_status(200) .with_header("content-type", "application/json") - .with_body(mock_response.to_string()) + .with_body(MOCK_RESPONSE_1) .create(); // Get the decision stream. @@ -163,11 +191,104 @@ async fn test_get_decisions_stream_startup() { let new = stream.new.unwrap(); assert_eq!(2, new.len()); assert_eq!("1.2.3.4", new[0].value); - assert_eq!("1.2.3.5", new[1].value); + assert_eq!("1.1.0.0/22", new[1].value); let deleted = stream.deleted.unwrap(); assert_eq!(2, deleted.len()); - assert_eq!("1.1.1.1", deleted[0].value); - assert_eq!("1.1.1.2", deleted[1].value); + assert_eq!("3.3.3.3", deleted[0].value); + assert_eq!("4.4.4.4", deleted[1].value); + + // Clean up the mock server. + mock_server.assert(); +} + +#[test] +async fn test_update_decisions() { + let api_key = "my_api_key"; + + // Simulate a stream of decisions. + let mut server = mockito::Server::new(); + let mock_server = server + .mock("GET", "/v1/decisions/stream") + .match_header("X-Api-Key", api_key) + .match_query(format!("startup=true&scope=Ip%2CRange").as_str()) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(MOCK_RESPONSE_1) + .create(); + + // Set up test data + let config = Config { + crowdsec_live_url: "".to_string(), + crowdsec_stream_url: server.url() + "/v1/decisions/stream", + crowdsec_api_key: api_key.to_string(), + crowdsec_mode: CrowdSecMode::Stream, + crowdsec_cache_ttl: 0, + stream_interval: 0, + port: 0, + }; + let health_status = Arc::new(Mutex::new(HealthStatus::new())); + let ipv4_table = Arc::new(Mutex::new(IpLookupTable::new())); + let ipv6_table = Arc::new(Mutex::new(IpLookupTable::new())); + let mut startup = true; + + update_decisions( + config.clone(), + health_status.clone(), + ipv4_table.clone(), + ipv6_table.clone(), + &mut startup, + ) + .await; + + // Check the result. + if let Ok(ipv4_table) = ipv4_table.lock() { + assert_eq!(2, ipv4_table.len()); + assert!(ipv4_table + .exact_match(Ipv4Addr::new(1, 2, 3, 4), 32) + .is_some()); + assert!(ipv4_table + .exact_match(Ipv4Addr::new(1, 1, 0, 0), 22) + .is_some()); + } else { + panic!("Expected an IPv4 table."); + } + + // Clean up the mock server. + mock_server.assert(); + + let mock_server = server + .mock("GET", "/v1/decisions/stream") + .match_header("X-Api-Key", api_key) + .match_query(format!("startup=false&scope=Ip%2CRange").as_str()) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(MOCK_RESPONSE_2) + .create(); + + update_decisions( + config.clone(), + health_status.clone(), + ipv4_table.clone(), + ipv6_table.clone(), + &mut startup, + ) + .await; + + // Check the result. + if let Ok(ipv4_table) = ipv4_table.lock() { + assert_eq!(2, ipv4_table.len()); + assert!(ipv4_table + .exact_match(Ipv4Addr::new(1, 2, 3, 4), 32) + .is_some()); + assert!(ipv4_table + .exact_match(Ipv4Addr::new(1, 1, 0, 0), 22) + .is_none()); + assert!(ipv4_table + .exact_match(Ipv4Addr::new(2, 2, 0, 0), 16) + .is_some()); + } else { + panic!("Expected an IPv4 table."); + } // Clean up the mock server. mock_server.assert(); diff --git a/src/errors.rs b/src/errors.rs index b79d149..4b19f29 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -9,3 +9,8 @@ custom_error! { ResponseParsingFailed{error: String} = "CrowdSec response parsing failed. Error: {error}", UrlParsingFailed{error: ParseError} = "CrowdSec url parsing failed. Error: {error}", } + +custom_error! { + pub TraefikError + BadHeaders = "Bad Forward-Request headers.", +} diff --git a/src/main.rs b/src/main.rs index f8ad00a..cf9d2bc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,7 +44,7 @@ async fn main() -> io::Result<()> { info!("Starting CrowdSec stream update."); // Update the IP tables from CrowdSec stream. actix_rt::spawn(async move { - crowdsec::stream( + crowdsec::stream_loop_thread( config_clone, health_status_clone, ipv4_table_clone,