From 2991bf99b2dd7d489abe42ac25fdcf672ed9a872 Mon Sep 17 00:00:00 2001 From: Ryan Cao <70191398+ryanccn@users.noreply.github.com> Date: Mon, 18 Nov 2024 18:54:43 +0800 Subject: [PATCH] fix(safe_browsing): improve update logic and logging --- src/safe_browsing/mod.rs | 37 +++++++++++++++++++++++-------------- src/safe_browsing/models.rs | 15 ++++++++++----- src/utils/error_handling.rs | 23 +++++++++++------------ 3 files changed, 44 insertions(+), 31 deletions(-) diff --git a/src/safe_browsing/mod.rs b/src/safe_browsing/mod.rs index aadfe39..18775bf 100644 --- a/src/safe_browsing/mod.rs +++ b/src/safe_browsing/mod.rs @@ -16,6 +16,7 @@ use std::{ use tokio::sync::RwLock; use crate::reqwest_client::HTTP; +use crate::utils::Pluralize as _; mod canonicalize; mod models; @@ -87,6 +88,7 @@ impl SafeBrowsing { .json(&request) .send() .await? + .error_for_status()? .json() .await?; @@ -99,21 +101,25 @@ impl SafeBrowsing { .map(|s| s.prefixes.clone()) .unwrap_or_default(); - for entry_set in &list_update.removals { - if let Some(raw_indices) = &entry_set.raw_indices { - for (idx_idx, idx) in raw_indices.indices.iter().enumerate() { - current_prefixes.remove(idx - idx_idx); + for removal in &list_update.removals { + let mut reversed_indices = removal.raw_indices.indices.clone(); + reversed_indices.sort_unstable(); + + for idx in reversed_indices.into_iter().rev() { + if idx < current_prefixes.len() { + current_prefixes.remove(idx); } } } - for entry_set in &list_update.additions { - if let Some(raw_hashes) = &entry_set.raw_hashes { - let hashes = BASE64.decode(&raw_hashes.raw_hashes)?; + for addition in &list_update.additions { + let hashes = BASE64.decode(&addition.raw_hashes.raw_hashes)?; - current_prefixes - .extend(hashes.chunks(raw_hashes.prefix_size).map(|c| c.to_vec())); - } + current_prefixes.extend( + hashes + .chunks(addition.raw_hashes.prefix_size) + .map(|c| c.to_vec()), + ); } current_prefixes.sort_unstable(); @@ -134,7 +140,7 @@ impl SafeBrowsing { list_update.checksum.sha256 ); - self.states.write().await.clear(); + self.states.write().await.remove(&list_update.threat_type); self.update().await?; return Ok(()); @@ -262,18 +268,21 @@ impl SafeBrowsing { .collect::>(); tracing::trace!( - "Scanned {} URLs in {:.2}ms (prefixes matched) => {} matches", + "Scanned {} {} in {:.2}ms (prefixes matched) => {} {}", urls.len(), + "URL".pluralize(urls.len()), bench_start.elapsed().as_millis(), - matches.len() + matches.len(), + "match".pluralize_alternate(matches.len(), "matches") ); return Ok(matches); } tracing::trace!( - "Scanned {} URLs in {:.2}ms (no prefixes matched) => no matches", + "Scanned {} {} in {:.2}ms (no prefixes matched) => no matches", urls.len(), + "URL".pluralize(urls.len()), bench_start.elapsed().as_millis(), ); diff --git a/src/safe_browsing/models.rs b/src/safe_browsing/models.rs index d45b2c3..be76630 100644 --- a/src/safe_browsing/models.rs +++ b/src/safe_browsing/models.rs @@ -60,9 +60,9 @@ pub struct ListUpdateResponse { pub checksum: ListUpdateChecksum, #[serde(default)] - pub additions: Vec, + pub additions: Vec, #[serde(default)] - pub removals: Vec, + pub removals: Vec, } #[derive(Debug, Clone, Deserialize)] @@ -73,9 +73,14 @@ pub struct ListUpdateChecksum { #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct ThreatEntrySet { - pub raw_hashes: Option, - pub raw_indices: Option, +pub struct ListUpdateAdditions { + pub raw_hashes: RawHashes, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListUpdateRemovals { + pub raw_indices: RawIndices, } #[derive(Debug, Clone, Deserialize)] diff --git a/src/utils/error_handling.rs b/src/utils/error_handling.rs index 3a11e67..fad9f98 100644 --- a/src/utils/error_handling.rs +++ b/src/utils/error_handling.rs @@ -23,16 +23,6 @@ pub enum ErrorOrPanic<'a> { Panic(&'a Option), } -impl ErrorOrPanic<'_> { - /// Return whether `self` is a panic or an error. - fn type_string(&self) -> String { - match self { - Self::Panic(_) => "panic".to_owned(), - Self::Error(_) => "error".to_owned(), - } - } -} - /// A wrapped type around errors or panics encapsulated in [`ErrorOrPanic`] that includes context from Poise and a randomly generated `error_id`. #[derive(Debug)] pub struct ValfiskError<'a> { @@ -66,9 +56,18 @@ impl ValfiskError<'_> { } /// Log the error to the console. - #[tracing::instrument(skip(self), fields(id = self.error_id, r#type = self.error_or_panic.type_string(), command = self.ctx.invocation_string(), channel = self.ctx.channel_id().get(), author = self.ctx.author().id.get()))] + #[tracing::instrument(skip(self))] pub fn handle_log(&self) { - error!("{:?}", self.error_or_panic); + error!( + { + id = self.error_id, + command = self.ctx.invocation_string(), + channel = self.ctx.channel_id().get(), + author = self.ctx.author().id.get() + }, + "{:?}", + self.error_or_panic, + ); } /// Reply to the interaction with an embed informing the user of an error, containing the randomly generated error ID.