Skip to content

Commit

Permalink
fix(safe_browsing): improve update logic and logging
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanccn committed Nov 18, 2024
1 parent 75b8c58 commit 2991bf9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 31 deletions.
37 changes: 23 additions & 14 deletions src/safe_browsing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::{
use tokio::sync::RwLock;

use crate::reqwest_client::HTTP;
use crate::utils::Pluralize as _;

mod canonicalize;
mod models;
Expand Down Expand Up @@ -87,6 +88,7 @@ impl SafeBrowsing {
.json(&request)
.send()
.await?
.error_for_status()?
.json()
.await?;

Expand All @@ -99,21 +101,25 @@ impl SafeBrowsing {
.map(|s| s.prefixes.clone())
.unwrap_or_default();

for entry_set in &list_update.removals {
if let Some(raw_indices) = &entry_set.raw_indices {
for (idx_idx, idx) in raw_indices.indices.iter().enumerate() {
current_prefixes.remove(idx - idx_idx);
for removal in &list_update.removals {
let mut reversed_indices = removal.raw_indices.indices.clone();
reversed_indices.sort_unstable();

for idx in reversed_indices.into_iter().rev() {
if idx < current_prefixes.len() {
current_prefixes.remove(idx);
}
}
}

for entry_set in &list_update.additions {
if let Some(raw_hashes) = &entry_set.raw_hashes {
let hashes = BASE64.decode(&raw_hashes.raw_hashes)?;
for addition in &list_update.additions {
let hashes = BASE64.decode(&addition.raw_hashes.raw_hashes)?;

current_prefixes
.extend(hashes.chunks(raw_hashes.prefix_size).map(|c| c.to_vec()));
}
current_prefixes.extend(
hashes
.chunks(addition.raw_hashes.prefix_size)
.map(|c| c.to_vec()),
);
}

current_prefixes.sort_unstable();
Expand All @@ -134,7 +140,7 @@ impl SafeBrowsing {
list_update.checksum.sha256
);

self.states.write().await.clear();
self.states.write().await.remove(&list_update.threat_type);
self.update().await?;

return Ok(());
Expand Down Expand Up @@ -262,18 +268,21 @@ impl SafeBrowsing {
.collect::<Vec<_>>();

tracing::trace!(
"Scanned {} URLs in {:.2}ms (prefixes matched) => {} matches",
"Scanned {} {} in {:.2}ms (prefixes matched) => {} {}",
urls.len(),
"URL".pluralize(urls.len()),
bench_start.elapsed().as_millis(),
matches.len()
matches.len(),
"match".pluralize_alternate(matches.len(), "matches")
);

return Ok(matches);
}

tracing::trace!(
"Scanned {} URLs in {:.2}ms (no prefixes matched) => no matches",
"Scanned {} {} in {:.2}ms (no prefixes matched) => no matches",
urls.len(),
"URL".pluralize(urls.len()),
bench_start.elapsed().as_millis(),
);

Expand Down
15 changes: 10 additions & 5 deletions src/safe_browsing/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ pub struct ListUpdateResponse {
pub checksum: ListUpdateChecksum,

#[serde(default)]
pub additions: Vec<ThreatEntrySet>,
pub additions: Vec<ListUpdateAdditions>,
#[serde(default)]
pub removals: Vec<ThreatEntrySet>,
pub removals: Vec<ListUpdateRemovals>,
}

#[derive(Debug, Clone, Deserialize)]
Expand All @@ -73,9 +73,14 @@ pub struct ListUpdateChecksum {

#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThreatEntrySet {
pub raw_hashes: Option<RawHashes>,
pub raw_indices: Option<RawIndices>,
pub struct ListUpdateAdditions {
pub raw_hashes: RawHashes,
}

#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListUpdateRemovals {
pub raw_indices: RawIndices,
}

#[derive(Debug, Clone, Deserialize)]
Expand Down
23 changes: 11 additions & 12 deletions src/utils/error_handling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,6 @@ pub enum ErrorOrPanic<'a> {
Panic(&'a Option<String>),
}

impl ErrorOrPanic<'_> {
/// Return whether `self` is a panic or an error.
fn type_string(&self) -> String {
match self {
Self::Panic(_) => "panic".to_owned(),
Self::Error(_) => "error".to_owned(),
}
}
}

/// A wrapped type around errors or panics encapsulated in [`ErrorOrPanic`] that includes context from Poise and a randomly generated `error_id`.
#[derive(Debug)]
pub struct ValfiskError<'a> {
Expand Down Expand Up @@ -66,9 +56,18 @@ impl ValfiskError<'_> {
}

/// Log the error to the console.
#[tracing::instrument(skip(self), fields(id = self.error_id, r#type = self.error_or_panic.type_string(), command = self.ctx.invocation_string(), channel = self.ctx.channel_id().get(), author = self.ctx.author().id.get()))]
#[tracing::instrument(skip(self))]
pub fn handle_log(&self) {
error!("{:?}", self.error_or_panic);
error!(
{
id = self.error_id,
command = self.ctx.invocation_string(),
channel = self.ctx.channel_id().get(),
author = self.ctx.author().id.get()
},
"{:?}",
self.error_or_panic,
);
}

/// Reply to the interaction with an embed informing the user of an error, containing the randomly generated error ID.
Expand Down

0 comments on commit 2991bf9

Please sign in to comment.