From 5ac65546ba70dddfe9af3986d8d6d8cb4f348c6e Mon Sep 17 00:00:00 2001 From: Ryan Cao <70191398+ryanccn@users.noreply.github.com> Date: Sat, 16 Nov 2024 19:09:51 +0800 Subject: [PATCH] fix(safe_browsing): populate database on startup --- src/main.rs | 4 ++++ src/safe_browsing/mod.rs | 35 +++++++++++++++-------------------- src/safe_browsing/models.rs | 10 +++++++--- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/main.rs b/src/main.rs index 10ce956..5628d02 100644 --- a/src/main.rs +++ b/src/main.rs @@ -275,6 +275,10 @@ async fn main() -> Result<()> { let data = Arc::new(Data::new()?); + if let Some(safe_browsing) = &data.safe_browsing { + safe_browsing.update().await?; + } + let mut client = serenity::Client::builder(&CONFIG.discord_token, serenity::GatewayIntents::all()) .framework(Framework::new(FrameworkOptions { diff --git a/src/safe_browsing/mod.rs b/src/safe_browsing/mod.rs index 8fc1e30..4d05b6a 100644 --- a/src/safe_browsing/mod.rs +++ b/src/safe_browsing/mod.rs @@ -99,29 +99,25 @@ impl SafeBrowsing { .map(|s| s.prefixes.clone()) .unwrap_or_default(); - if let Some(removals) = list_update.removals { - for entry_set in removals { - if let Some(raw_indices) = entry_set.raw_indices { - for index in raw_indices.indices { - if (index as usize) < current_prefixes.len() { - current_prefixes.remove(index as usize); - } + for entry_set in &list_update.removals { + if let Some(raw_indices) = &entry_set.raw_indices { + for index in &raw_indices.indices { + if (*index as usize) < current_prefixes.len() { + current_prefixes.remove(*index as usize); } } } } - if let Some(additions) = list_update.additions { - for entry_set in additions { - if let Some(raw_hashes) = entry_set.raw_hashes { - let hashes = BASE64.decode(raw_hashes.raw_hashes)?; + for entry_set in &list_update.additions { + if let Some(raw_hashes) = &entry_set.raw_hashes { + let hashes = BASE64.decode(&raw_hashes.raw_hashes)?; - current_prefixes.extend( - hashes - .chunks(raw_hashes.prefix_size as usize) - .map(|c| c.to_vec()), - ); - } + current_prefixes.extend( + hashes + .chunks(raw_hashes.prefix_size as usize) + .map(|c| c.to_vec()), + ); } } @@ -137,7 +133,7 @@ impl SafeBrowsing { if checksum != list_update.checksum.sha256 { tracing::error!( - "list {:?} checksum has drifted, correcting (actual: {:?}, expected: {:?})", + "List {:?} checksum has drifted, resetting (actual: {:?}, expected: {:?})", list_update.threat_type, checksum, list_update.checksum.sha256 @@ -165,7 +161,7 @@ impl SafeBrowsing { .await .values() .map(|v| v.prefixes.len()) - .sum::() + .sum::(), ); Ok(()) @@ -255,7 +251,6 @@ impl SafeBrowsing { let matches = response .matches - .unwrap_or_default() .into_par_iter() .filter_map(|m| { if let Ok(raw_threat_hash) = BASE64.decode(&m.threat.hash) { diff --git a/src/safe_browsing/models.rs b/src/safe_browsing/models.rs index df3ff78..f6b487e 100644 --- a/src/safe_browsing/models.rs +++ b/src/safe_browsing/models.rs @@ -57,9 +57,12 @@ pub struct ThreatListUpdateResponse { pub struct ListUpdateResponse { pub threat_type: String, pub new_client_state: String, - pub additions: Option>, - pub removals: Option>, pub checksum: ListUpdateChecksum, + + #[serde(default)] + pub additions: Vec, + #[serde(default)] + pub removals: Vec, } #[derive(Debug, Clone, Deserialize)] @@ -114,7 +117,8 @@ pub struct ThreatEntry { #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FindFullHashesResponse { - pub matches: Option>, + #[serde(default)] + pub matches: Vec, } #[derive(Debug, Clone, Deserialize)]