Skip to content

Commit

Permalink
fix(safe_browsing): populate database on startup
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanccn committed Nov 16, 2024
1 parent 69071b9 commit 5ac6554
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
4 changes: 4 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ async fn main() -> Result<()> {

let data = Arc::new(Data::new()?);

if let Some(safe_browsing) = &data.safe_browsing {
safe_browsing.update().await?;
}

let mut client =
serenity::Client::builder(&CONFIG.discord_token, serenity::GatewayIntents::all())
.framework(Framework::new(FrameworkOptions {
Expand Down
35 changes: 15 additions & 20 deletions src/safe_browsing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,29 +99,25 @@ impl SafeBrowsing {
.map(|s| s.prefixes.clone())
.unwrap_or_default();

if let Some(removals) = list_update.removals {
for entry_set in removals {
if let Some(raw_indices) = entry_set.raw_indices {
for index in raw_indices.indices {
if (index as usize) < current_prefixes.len() {
current_prefixes.remove(index as usize);
}
for entry_set in &list_update.removals {
if let Some(raw_indices) = &entry_set.raw_indices {
for index in &raw_indices.indices {
if (*index as usize) < current_prefixes.len() {
current_prefixes.remove(*index as usize);
}
}
}
}

if let Some(additions) = list_update.additions {
for entry_set in additions {
if let Some(raw_hashes) = entry_set.raw_hashes {
let hashes = BASE64.decode(raw_hashes.raw_hashes)?;
for entry_set in &list_update.additions {
if let Some(raw_hashes) = &entry_set.raw_hashes {
let hashes = BASE64.decode(&raw_hashes.raw_hashes)?;

current_prefixes.extend(
hashes
.chunks(raw_hashes.prefix_size as usize)
.map(|c| c.to_vec()),
);
}
current_prefixes.extend(
hashes
.chunks(raw_hashes.prefix_size as usize)
.map(|c| c.to_vec()),
);
}
}

Expand All @@ -137,7 +133,7 @@ impl SafeBrowsing {

if checksum != list_update.checksum.sha256 {
tracing::error!(
"list {:?} checksum has drifted, correcting (actual: {:?}, expected: {:?})",
"List {:?} checksum has drifted, resetting (actual: {:?}, expected: {:?})",
list_update.threat_type,
checksum,
list_update.checksum.sha256
Expand Down Expand Up @@ -165,7 +161,7 @@ impl SafeBrowsing {
.await
.values()
.map(|v| v.prefixes.len())
.sum::<usize>()
.sum::<usize>(),
);

Ok(())
Expand Down Expand Up @@ -255,7 +251,6 @@ impl SafeBrowsing {

let matches = response
.matches
.unwrap_or_default()
.into_par_iter()
.filter_map(|m| {
if let Ok(raw_threat_hash) = BASE64.decode(&m.threat.hash) {
Expand Down
10 changes: 7 additions & 3 deletions src/safe_browsing/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ pub struct ThreatListUpdateResponse {
pub struct ListUpdateResponse {
pub threat_type: String,
pub new_client_state: String,
pub additions: Option<Vec<ThreatEntrySet>>,
pub removals: Option<Vec<ThreatEntrySet>>,
pub checksum: ListUpdateChecksum,

#[serde(default)]
pub additions: Vec<ThreatEntrySet>,
#[serde(default)]
pub removals: Vec<ThreatEntrySet>,
}

#[derive(Debug, Clone, Deserialize)]
Expand Down Expand Up @@ -114,7 +117,8 @@ pub struct ThreatEntry {
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FindFullHashesResponse {
pub matches: Option<Vec<ThreatMatch>>,
#[serde(default)]
pub matches: Vec<ThreatMatch>,
}

#[derive(Debug, Clone, Deserialize)]
Expand Down

0 comments on commit 5ac6554

Please sign in to comment.